akhaliq3 commited on
Commit
a8c8bc6
1 Parent(s): 15b8733
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE.txt +97 -0
  2. app.py +169 -0
  3. avg_spectra.py +276 -0
  4. calc_metrics.py +188 -0
  5. dataset_tool.py +455 -0
  6. dnnlib/__init__.py +9 -0
  7. dnnlib/util.py +491 -0
  8. environment.yml +24 -0
  9. gen_images.py +144 -0
  10. gen_video.py +178 -0
  11. gui_utils/__init__.py +9 -0
  12. gui_utils/gl_utils.py +374 -0
  13. gui_utils/glfw_window.py +229 -0
  14. gui_utils/imgui_utils.py +169 -0
  15. gui_utils/imgui_window.py +103 -0
  16. gui_utils/text_utils.py +123 -0
  17. legacy.py +323 -0
  18. metrics/__init__.py +9 -0
  19. metrics/equivariance.py +267 -0
  20. metrics/frechet_inception_distance.py +41 -0
  21. metrics/inception_score.py +38 -0
  22. metrics/kernel_inception_distance.py +46 -0
  23. metrics/metric_main.py +153 -0
  24. metrics/metric_utils.py +279 -0
  25. metrics/perceptual_path_length.py +125 -0
  26. metrics/precision_recall.py +62 -0
  27. torch_utils/__init__.py +9 -0
  28. torch_utils/custom_ops.py +157 -0
  29. torch_utils/misc.py +266 -0
  30. torch_utils/ops/__init__.py +9 -0
  31. torch_utils/ops/bias_act.cpp +99 -0
  32. torch_utils/ops/bias_act.cu +173 -0
  33. torch_utils/ops/bias_act.h +38 -0
  34. torch_utils/ops/bias_act.py +209 -0
  35. torch_utils/ops/conv2d_gradfix.py +198 -0
  36. torch_utils/ops/conv2d_resample.py +143 -0
  37. torch_utils/ops/filtered_lrelu.cpp +300 -0
  38. torch_utils/ops/filtered_lrelu.cu +1284 -0
  39. torch_utils/ops/filtered_lrelu.h +90 -0
  40. torch_utils/ops/filtered_lrelu.py +274 -0
  41. torch_utils/ops/filtered_lrelu_ns.cu +27 -0
  42. torch_utils/ops/filtered_lrelu_rd.cu +27 -0
  43. torch_utils/ops/filtered_lrelu_wr.cu +27 -0
  44. torch_utils/ops/fma.py +60 -0
  45. torch_utils/ops/grid_sample_gradfix.py +77 -0
  46. torch_utils/ops/upfirdn2d.cpp +107 -0
  47. torch_utils/ops/upfirdn2d.cu +384 -0
  48. torch_utils/ops/upfirdn2d.h +59 -0
  49. torch_utils/ops/upfirdn2d.py +389 -0
  50. torch_utils/persistence.py +251 -0
LICENSE.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved.
2
+
3
+
4
+ NVIDIA Source Code License for StyleGAN3
5
+
6
+
7
+ =======================================================================
8
+
9
+ 1. Definitions
10
+
11
+ "Licensor" means any person or entity that distributes its Work.
12
+
13
+ "Software" means the original work of authorship made available under
14
+ this License.
15
+
16
+ "Work" means the Software and any additions to or derivative works of
17
+ the Software that are made available under this License.
18
+
19
+ The terms "reproduce," "reproduction," "derivative works," and
20
+ "distribution" have the meaning as provided under U.S. copyright law;
21
+ provided, however, that for the purposes of this License, derivative
22
+ works shall not include works that remain separable from, or merely
23
+ link (or bind by name) to the interfaces of, the Work.
24
+
25
+ Works, including the Software, are "made available" under this License
26
+ by including in or with the Work either (a) a copyright notice
27
+ referencing the applicability of this License to the Work, or (b) a
28
+ copy of this License.
29
+
30
+ 2. License Grants
31
+
32
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
33
+ License, each Licensor grants to you a perpetual, worldwide,
34
+ non-exclusive, royalty-free, copyright license to reproduce,
35
+ prepare derivative works of, publicly display, publicly perform,
36
+ sublicense and distribute its Work and any resulting derivative
37
+ works in any form.
38
+
39
+ 3. Limitations
40
+
41
+ 3.1 Redistribution. You may reproduce or distribute the Work only
42
+ if (a) you do so under this License, (b) you include a complete
43
+ copy of this License with your distribution, and (c) you retain
44
+ without modification any copyright, patent, trademark, or
45
+ attribution notices that are present in the Work.
46
+
47
+ 3.2 Derivative Works. You may specify that additional or different
48
+ terms apply to the use, reproduction, and distribution of your
49
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
50
+ provide that the use limitation in Section 3.3 applies to your
51
+ derivative works, and (b) you identify the specific derivative
52
+ works that are subject to Your Terms. Notwithstanding Your Terms,
53
+ this License (including the redistribution requirements in Section
54
+ 3.1) will continue to apply to the Work itself.
55
+
56
+ 3.3 Use Limitation. The Work and any derivative works thereof only
57
+ may be used or intended for use non-commercially. Notwithstanding
58
+ the foregoing, NVIDIA and its affiliates may use the Work and any
59
+ derivative works commercially. As used herein, "non-commercially"
60
+ means for research or evaluation purposes only.
61
+
62
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
63
+ against any Licensor (including any claim, cross-claim or
64
+ counterclaim in a lawsuit) to enforce any patents that you allege
65
+ are infringed by any Work, then your rights under this License from
66
+ such Licensor (including the grant in Section 2.1) will terminate
67
+ immediately.
68
+
69
+ 3.5 Trademarks. This License does not grant any rights to use any
70
+ Licensor’s or its affiliates’ names, logos, or trademarks, except
71
+ as necessary to reproduce the notices described in this License.
72
+
73
+ 3.6 Termination. If you violate any term of this License, then your
74
+ rights under this License (including the grant in Section 2.1) will
75
+ terminate immediately.
76
+
77
+ 4. Disclaimer of Warranty.
78
+
79
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
80
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
81
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
82
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
83
+ THIS LICENSE.
84
+
85
+ 5. Limitation of Liability.
86
+
87
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95
+ THE POSSIBILITY OF SUCH DAMAGES.
96
+
97
+ =======================================================================
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system("pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html")
3
+ os.system("git clone https://github.com/NVlabs/stylegan3")
4
+ os.system("git clone https://github.com/openai/CLIP")
5
+ os.system("pip install -e ./CLIP")
6
+ os.system("pip install einops ninja scipy numpy Pillow tqdm")
7
+ import sys
8
+ sys.path.append('./CLIP')
9
+ sys.path.append('./stylegan3')
10
+ import io
11
+ import os, time
12
+ import pickle
13
+ import shutil
14
+ import numpy as np
15
+ from PIL import Image
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import requests
19
+ import torchvision.transforms as transforms
20
+ import torchvision.transforms.functional as TF
21
+ import clip
22
+ from tqdm.notebook import tqdm
23
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
24
+ from einops import rearrange
25
+ device = torch.device('cuda:0')
26
+ def fetch(url_or_path):
27
+ if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
28
+ r = requests.get(url_or_path)
29
+ r.raise_for_status()
30
+ fd = io.BytesIO()
31
+ fd.write(r.content)
32
+ fd.seek(0)
33
+ return fd
34
+ return open(url_or_path, 'rb')
35
+ def fetch_model(url_or_path):
36
+ basename = os.path.basename(url_or_path)
37
+ if os.path.exists(basename):
38
+ return basename
39
+ else:
40
+ os.system("wget -c '{url_or_path}'")
41
+ return basename
42
+ def norm1(prompt):
43
+ "Normalize to the unit sphere."
44
+ return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()
45
+ def spherical_dist_loss(x, y):
46
+ x = F.normalize(x, dim=-1)
47
+ y = F.normalize(y, dim=-1)
48
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
49
+ class MakeCutouts(torch.nn.Module):
50
+ def __init__(self, cut_size, cutn, cut_pow=1.):
51
+ super().__init__()
52
+ self.cut_size = cut_size
53
+ self.cutn = cutn
54
+ self.cut_pow = cut_pow
55
+ def forward(self, input):
56
+ sideY, sideX = input.shape[2:4]
57
+ max_size = min(sideX, sideY)
58
+ min_size = min(sideX, sideY, self.cut_size)
59
+ cutouts = []
60
+ for _ in range(self.cutn):
61
+ size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
62
+ offsetx = torch.randint(0, sideX - size + 1, ())
63
+ offsety = torch.randint(0, sideY - size + 1, ())
64
+ cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
65
+ cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
66
+ return torch.cat(cutouts)
67
+ make_cutouts = MakeCutouts(224, 32, 0.5)
68
+ def embed_image(image):
69
+ n = image.shape[0]
70
+ cutouts = make_cutouts(image)
71
+ embeds = clip_model.embed_cutout(cutouts)
72
+ embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
73
+ return embeds
74
+ def embed_url(url):
75
+ image = Image.open(fetch(url)).convert('RGB')
76
+ return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
77
+ class CLIP(object):
78
+ def __init__(self):
79
+ clip_model = "ViT-B/32"
80
+ self.model, _ = clip.load(clip_model)
81
+ self.model = self.model.requires_grad_(False)
82
+ self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
83
+ std=[0.26862954, 0.26130258, 0.27577711])
84
+ @torch.no_grad()
85
+ def embed_text(self, prompt):
86
+ "Normalized clip text embedding."
87
+ return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
88
+ def embed_cutout(self, image):
89
+ "Normalized clip image embedding."
90
+ return norm1(self.model.encode_image(self.normalize(image)))
91
+
92
+ clip_model = CLIP()
93
+ # Load stylegan model
94
+ base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/"
95
+ model_name = "stylegan3-t-ffhqu-1024x1024.pkl"
96
+ #model_name = "stylegan3-r-metfacesu-1024x1024.pkl"
97
+ #model_name = "stylegan3-t-afhqv2-512x512.pkl"
98
+ network_url = base_url + model_name
99
+ os.system("wget -c https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-1024x1024.pkl")
100
+ with open('stylegan3-t-ffhqu-1024x1024.pkl', 'rb') as fp:
101
+ G = pickle.load(fp)['G_ema'].to(device)
102
+ zs = torch.randn([10000, G.mapping.z_dim], device=device)
103
+ w_stds = G.mapping(zs, None).std(0)
104
+
105
+
106
+ def inference(text):
107
+ target = clip_model.embed_text(text)
108
+ steps = 600
109
+ seed = 2
110
+ tf = Compose([
111
+ Resize(224),
112
+ lambda x: torch.clamp((x+1)/2,min=0,max=1),
113
+ ])
114
+ torch.manual_seed(seed)
115
+ timestring = time.strftime('%Y%m%d%H%M%S')
116
+ with torch.no_grad():
117
+ qs = []
118
+ losses = []
119
+ for _ in range(8):
120
+ q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
121
+ images = G.synthesis(q * w_stds + G.mapping.w_avg)
122
+ embeds = embed_image(images.add(1).div(2))
123
+ loss = spherical_dist_loss(embeds, target).mean(0)
124
+ i = torch.argmin(loss)
125
+ qs.append(q[i])
126
+ losses.append(loss[i])
127
+ qs = torch.stack(qs)
128
+ losses = torch.stack(losses)
129
+ print(losses)
130
+ print(losses.shape, qs.shape)
131
+ i = torch.argmin(losses)
132
+ q = qs[i].unsqueeze(0)
133
+ q.requires_grad_()
134
+ q_ema = q
135
+ opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
136
+ loop = tqdm(range(steps))
137
+ for i in loop:
138
+ opt.zero_grad()
139
+ w = q * w_stds
140
+ image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
141
+ embed = embed_image(image.add(1).div(2))
142
+ loss = spherical_dist_loss(embed, target).mean()
143
+ loss.backward()
144
+ opt.step()
145
+ loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
146
+ q_ema = q_ema * 0.9 + q * 0.1
147
+ image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
148
+ if i % 10 == 0:
149
+ display(TF.to_pil_image(tf(image)[0]))
150
+ pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))
151
+ #os.makedirs(f'samples/{timestring}', exist_ok=True)
152
+ #pil_image.save(f'samples/{timestring}/{i:04}.jpg')
153
+ return pil_image
154
+
155
+
156
+ title = "StyleGAN+CLIP_with_Latent_Bootstraping"
157
+ description = "Gradio demo for StyleGAN+CLIP_with_Latent_Bootstraping. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
158
+ article = "<p style='text-align: center'>colab by https://twitter.com/EricHallahan <a href='https://colab.research.google.com/drive/1br7GP_D6XCgulxPTAFhwGaV-ijFe084X' target='_blank'>Colab</a></p>"
159
+ examples = [['elon musk']]
160
+ gr.Interface(
161
+ inference,
162
+ "text",
163
+ gr.outputs.Image(type="pil", label="Output"),
164
+ title=title,
165
+ description=description,
166
+ article=article,
167
+ enable_queue=True,
168
+ examples=examples
169
+ ).launch(debug=True)
avg_spectra.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Compare average power spectra between real and generated images,
10
+ or between multiple generators."""
11
+
12
+ import os
13
+ import numpy as np
14
+ import torch
15
+ import torch.fft
16
+ import scipy.ndimage
17
+ import matplotlib.pyplot as plt
18
+ import click
19
+ import tqdm
20
+ import dnnlib
21
+
22
+ import legacy
23
+ from training import dataset
24
+
25
+ #----------------------------------------------------------------------------
26
+ # Setup an iterator for streaming images, in uint8 NCHW format, based on the
27
+ # respective command line options.
28
+
29
+ def stream_source_images(source, num, seed, device, data_loader_kwargs=None): # => num_images, image_size, image_iter
30
+ ext = source.split('.')[-1].lower()
31
+ if data_loader_kwargs is None:
32
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
33
+
34
+ if ext == 'pkl':
35
+ if num is None:
36
+ raise click.ClickException('--num is required when --source points to network pickle')
37
+ with dnnlib.util.open_url(source) as f:
38
+ G = legacy.load_network_pkl(f)['G_ema'].to(device)
39
+ def generate_image(seed):
40
+ rnd = np.random.RandomState(seed)
41
+ z = torch.from_numpy(rnd.randn(1, G.z_dim)).to(device)
42
+ c = torch.zeros([1, G.c_dim], device=device)
43
+ if G.c_dim > 0:
44
+ c[:, rnd.randint(G.c_dim)] = 1
45
+ return (G(z=z, c=c) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
46
+ _ = generate_image(seed) # warm up
47
+ image_iter = (generate_image(seed + idx) for idx in range(num))
48
+ return num, G.img_resolution, image_iter
49
+
50
+ elif ext == 'zip' or os.path.isdir(source):
51
+ dataset_obj = dataset.ImageFolderDataset(path=source, max_size=num, random_seed=seed)
52
+ if num is not None and num != len(dataset_obj):
53
+ raise click.ClickException(f'--source contains fewer than {num} images')
54
+ data_loader = torch.utils.data.DataLoader(dataset_obj, batch_size=1, **data_loader_kwargs)
55
+ image_iter = (image.to(device) for image, _label in data_loader)
56
+ return len(dataset_obj), dataset_obj.resolution, image_iter
57
+
58
+ else:
59
+ raise click.ClickException('--source must point to network pickle, dataset zip, or directory')
60
+
61
+ #----------------------------------------------------------------------------
62
+ # Load average power spectrum from the specified .npz file and construct
63
+ # the corresponding heatmap for visualization.
64
+
65
+ def construct_heatmap(npz_file, smooth):
66
+ npz_data = np.load(npz_file)
67
+ spectrum = npz_data['spectrum']
68
+ image_size = npz_data['image_size']
69
+ hmap = np.log10(spectrum) * 10 # dB
70
+ hmap = np.fft.fftshift(hmap)
71
+ hmap = np.concatenate([hmap, hmap[:1, :]], axis=0)
72
+ hmap = np.concatenate([hmap, hmap[:, :1]], axis=1)
73
+ if smooth > 0:
74
+ sigma = spectrum.shape[0] / image_size * smooth
75
+ hmap = scipy.ndimage.gaussian_filter(hmap, sigma=sigma, mode='nearest')
76
+ return hmap, image_size
77
+
78
+ #----------------------------------------------------------------------------
79
+
80
+ @click.group()
81
+ def main():
82
+ """Compare average power spectra between real and generated images,
83
+ or between multiple generators.
84
+
85
+ Example:
86
+
87
+ \b
88
+ # Calculate dataset mean and std, needed in subsequent steps.
89
+ python avg_spectra.py stats --source=~/datasets/ffhq-1024x1024.zip
90
+
91
+ \b
92
+ # Calculate average spectrum for the training data.
93
+ python avg_spectra.py calc --source=~/datasets/ffhq-1024x1024.zip \\
94
+ --dest=tmp/training-data.npz --mean=112.684 --std=69.509
95
+
96
+ \b
97
+ # Calculate average spectrum for a pre-trained generator.
98
+ python avg_spectra.py calc \\
99
+ --source=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl \\
100
+ --dest=tmp/stylegan3-r.npz --mean=112.684 --std=69.509 --num=70000
101
+
102
+ \b
103
+ # Display results.
104
+ python avg_spectra.py heatmap tmp/training-data.npz
105
+ python avg_spectra.py heatmap tmp/stylegan3-r.npz
106
+ python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz
107
+
108
+ \b
109
+ # Save as PNG.
110
+ python avg_spectra.py heatmap tmp/training-data.npz --save=tmp/training-data.png --dpi=300
111
+ python avg_spectra.py heatmap tmp/stylegan3-r.npz --save=tmp/stylegan3-r.png --dpi=300
112
+ python avg_spectra.py slices tmp/training-data.npz tmp/stylegan3-r.npz --save=tmp/slices.png --dpi=300
113
+ """
114
+
115
+ #----------------------------------------------------------------------------
116
+
117
+ @main.command()
118
+ @click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True)
119
+ @click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1))
120
+ @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
121
+ def stats(source, num, seed, device=torch.device('cuda')):
122
+ """Calculate dataset mean and standard deviation needed by 'calc'."""
123
+ torch.multiprocessing.set_start_method('spawn')
124
+ num_images, _image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device)
125
+
126
+ # Accumulate moments.
127
+ moments = torch.zeros([3], dtype=torch.float64, device=device)
128
+ for image in tqdm.tqdm(image_iter, total=num_images):
129
+ image = image.to(torch.float64)
130
+ moments += torch.stack([torch.ones_like(image).sum(), image.sum(), image.square().sum()])
131
+ moments = moments / moments[0]
132
+
133
+ # Compute mean and standard deviation.
134
+ mean = moments[1]
135
+ std = (moments[2] - moments[1].square()).sqrt()
136
+ print(f'--mean={mean:g} --std={std:g}')
137
+
138
+ #----------------------------------------------------------------------------
139
+
140
+ @main.command()
141
+ @click.option('--source', help='Network pkl, dataset zip, or directory', metavar='[PKL|ZIP|DIR]', required=True)
142
+ @click.option('--dest', help='Where to store the result', metavar='NPZ', required=True)
143
+ @click.option('--mean', help='Dataset mean for whitening', metavar='FLOAT', type=float, required=True)
144
+ @click.option('--std', help='Dataset standard deviation for whitening', metavar='FLOAT', type=click.FloatRange(min=0), required=True)
145
+ @click.option('--num', help='Number of images to process [default: all]', metavar='INT', type=click.IntRange(min=1))
146
+ @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
147
+ @click.option('--beta', help='Shape parameter for the Kaiser window', metavar='FLOAT', type=click.FloatRange(min=0), default=8, show_default=True)
148
+ @click.option('--interp', help='Frequency-domain interpolation factor', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True)
149
+ def calc(source, dest, mean, std, num, seed, beta, interp, device=torch.device('cuda')):
150
+ """Calculate average power spectrum and store it in .npz file."""
151
+ torch.multiprocessing.set_start_method('spawn')
152
+ num_images, image_size, image_iter = stream_source_images(source=source, num=num, seed=seed, device=device)
153
+ spectrum_size = image_size * interp
154
+ padding = spectrum_size - image_size
155
+
156
+ # Setup window function.
157
+ window = torch.kaiser_window(image_size, periodic=False, beta=beta, device=device)
158
+ window *= window.square().sum().rsqrt()
159
+ window = window.ger(window).unsqueeze(0).unsqueeze(1)
160
+
161
+ # Accumulate power spectrum.
162
+ spectrum = torch.zeros([spectrum_size, spectrum_size], dtype=torch.float64, device=device)
163
+ for image in tqdm.tqdm(image_iter, total=num_images):
164
+ image = (image.to(torch.float64) - mean) / std
165
+ image = torch.nn.functional.pad(image * window, [0, padding, 0, padding])
166
+ spectrum += torch.fft.fftn(image, dim=[2,3]).abs().square().mean(dim=[0,1])
167
+ spectrum /= num_images
168
+
169
+ # Save result.
170
+ if os.path.dirname(dest):
171
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
172
+ np.savez(dest, spectrum=spectrum.cpu().numpy(), image_size=image_size)
173
+
174
+ #----------------------------------------------------------------------------
175
+
176
+ @main.command()
177
+ @click.argument('npz-file', nargs=1)
178
+ @click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]')
179
+ @click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True)
180
+ @click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=1.25, show_default=True)
181
+ def heatmap(npz_file, save, smooth, dpi):
182
+ """Visualize 2D heatmap based on the given .npz file."""
183
+ hmap, image_size = construct_heatmap(npz_file=npz_file, smooth=smooth)
184
+
185
+ # Setup plot.
186
+ plt.figure(figsize=[6, 4.8], dpi=dpi, tight_layout=True)
187
+ freqs = np.linspace(-0.5, 0.5, num=hmap.shape[0], endpoint=True) * image_size
188
+ ticks = np.linspace(freqs[0], freqs[-1], num=5, endpoint=True)
189
+ levels = np.linspace(-40, 20, num=13, endpoint=True)
190
+
191
+ # Draw heatmap.
192
+ plt.xlim(ticks[0], ticks[-1])
193
+ plt.ylim(ticks[0], ticks[-1])
194
+ plt.xticks(ticks)
195
+ plt.yticks(ticks)
196
+ plt.contourf(freqs, freqs, hmap, levels=levels, extend='both', cmap='Blues')
197
+ plt.gca().set_aspect('equal')
198
+ plt.colorbar(ticks=levels)
199
+ plt.contour(freqs, freqs, hmap, levels=levels, extend='both', linestyles='solid', linewidths=1, colors='midnightblue', alpha=0.2)
200
+
201
+ # Display or save.
202
+ if save is None:
203
+ plt.show()
204
+ else:
205
+ if os.path.dirname(save):
206
+ os.makedirs(os.path.dirname(save), exist_ok=True)
207
+ plt.savefig(save)
208
+
209
+ #----------------------------------------------------------------------------
210
+
211
+ @main.command()
212
+ @click.argument('npz-files', nargs=-1, required=True)
213
+ @click.option('--save', help='Save the plot and exit', metavar='[PNG|PDF|...]')
214
+ @click.option('--dpi', help='Figure resolution', metavar='FLOAT', type=click.FloatRange(min=1), default=100, show_default=True)
215
+ @click.option('--smooth', help='Amount of smoothing', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True)
216
+ def slices(npz_files, save, dpi, smooth):
217
+ """Visualize 1D slices based on the given .npz files."""
218
+ cases = [dnnlib.EasyDict(npz_file=npz_file) for npz_file in npz_files]
219
+ for c in cases:
220
+ c.hmap, c.image_size = construct_heatmap(npz_file=c.npz_file, smooth=smooth)
221
+ c.label = os.path.splitext(os.path.basename(c.npz_file))[0]
222
+
223
+ # Check consistency.
224
+ image_size = cases[0].image_size
225
+ hmap_size = cases[0].hmap.shape[0]
226
+ if any(c.image_size != image_size or c.hmap.shape[0] != hmap_size for c in cases):
227
+ raise click.ClickException('All .npz must have the same resolution')
228
+
229
+ # Setup plot.
230
+ plt.figure(figsize=[12, 4.6], dpi=dpi, tight_layout=True)
231
+ hmap_center = hmap_size // 2
232
+ hmap_range = np.arange(hmap_center, hmap_size)
233
+ freqs0 = np.linspace(0, image_size / 2, num=(hmap_size // 2 + 1), endpoint=True)
234
+ freqs45 = np.linspace(0, image_size / np.sqrt(2), num=(hmap_size // 2 + 1), endpoint=True)
235
+ xticks0 = np.linspace(freqs0[0], freqs0[-1], num=9, endpoint=True)
236
+ xticks45 = np.round(np.linspace(freqs45[0], freqs45[-1], num=9, endpoint=True))
237
+ yticks = np.linspace(-50, 30, num=9, endpoint=True)
238
+
239
+ # Draw 0 degree slice.
240
+ plt.subplot(1, 2, 1)
241
+ plt.title('0\u00b0 slice')
242
+ plt.xlim(xticks0[0], xticks0[-1])
243
+ plt.ylim(yticks[0], yticks[-1])
244
+ plt.xticks(xticks0)
245
+ plt.yticks(yticks)
246
+ for c in cases:
247
+ plt.plot(freqs0, c.hmap[hmap_center, hmap_range], label=c.label)
248
+ plt.grid()
249
+ plt.legend(loc='upper right')
250
+
251
+ # Draw 45 degree slice.
252
+ plt.subplot(1, 2, 2)
253
+ plt.title('45\u00b0 slice')
254
+ plt.xlim(xticks45[0], xticks45[-1])
255
+ plt.ylim(yticks[0], yticks[-1])
256
+ plt.xticks(xticks45)
257
+ plt.yticks(yticks)
258
+ for c in cases:
259
+ plt.plot(freqs45, c.hmap[hmap_range, hmap_range], label=c.label)
260
+ plt.grid()
261
+ plt.legend(loc='upper right')
262
+
263
+ # Display or save.
264
+ if save is None:
265
+ plt.show()
266
+ else:
267
+ if os.path.dirname(save):
268
+ os.makedirs(os.path.dirname(save), exist_ok=True)
269
+ plt.savefig(save)
270
+
271
+ #----------------------------------------------------------------------------
272
+
273
+ if __name__ == "__main__":
274
+ main() # pylint: disable=no-value-for-parameter
275
+
276
+ #----------------------------------------------------------------------------
calc_metrics.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Calculate quality metrics for previous training run or pretrained network pickle."""
10
+
11
+ import os
12
+ import click
13
+ import json
14
+ import tempfile
15
+ import copy
16
+ import torch
17
+
18
+ import dnnlib
19
+ import legacy
20
+ from metrics import metric_main
21
+ from metrics import metric_utils
22
+ from torch_utils import training_stats
23
+ from torch_utils import custom_ops
24
+ from torch_utils import misc
25
+ from torch_utils.ops import conv2d_gradfix
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def subprocess_fn(rank, args, temp_dir):
30
+ dnnlib.util.Logger(should_flush=True)
31
+
32
+ # Init torch.distributed.
33
+ if args.num_gpus > 1:
34
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
35
+ if os.name == 'nt':
36
+ init_method = 'file:///' + init_file.replace('\\', '/')
37
+ torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
38
+ else:
39
+ init_method = f'file://{init_file}'
40
+ torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
41
+
42
+ # Init torch_utils.
43
+ sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
44
+ training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
45
+ if rank != 0 or not args.verbose:
46
+ custom_ops.verbosity = 'none'
47
+
48
+ # Configure torch.
49
+ device = torch.device('cuda', rank)
50
+ torch.backends.cuda.matmul.allow_tf32 = False
51
+ torch.backends.cudnn.allow_tf32 = False
52
+ conv2d_gradfix.enabled = True
53
+
54
+ # Print network summary.
55
+ G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
56
+ if rank == 0 and args.verbose:
57
+ z = torch.empty([1, G.z_dim], device=device)
58
+ c = torch.empty([1, G.c_dim], device=device)
59
+ misc.print_module_summary(G, [z, c])
60
+
61
+ # Calculate each metric.
62
+ for metric in args.metrics:
63
+ if rank == 0 and args.verbose:
64
+ print(f'Calculating {metric}...')
65
+ progress = metric_utils.ProgressMonitor(verbose=args.verbose)
66
+ result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
67
+ num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
68
+ if rank == 0:
69
+ metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
70
+ if rank == 0 and args.verbose:
71
+ print()
72
+
73
+ # Done.
74
+ if rank == 0 and args.verbose:
75
+ print('Exiting...')
76
+
77
+ #----------------------------------------------------------------------------
78
+
79
+ def parse_comma_separated_list(s):
80
+ if isinstance(s, list):
81
+ return s
82
+ if s is None or s.lower() == 'none' or s == '':
83
+ return []
84
+ return s.split(',')
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
+ @click.command()
89
+ @click.pass_context
90
+ @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
91
+ @click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k_full', show_default=True)
92
+ @click.option('--data', help='Dataset to evaluate against [default: look up]', metavar='[ZIP|DIR]')
93
+ @click.option('--mirror', help='Enable dataset x-flips [default: look up]', type=bool, metavar='BOOL')
94
+ @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
95
+ @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
96
+
97
+ def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
98
+ """Calculate quality metrics for previous training run or pretrained network pickle.
99
+
100
+ Examples:
101
+
102
+ \b
103
+ # Previous training run: look up options automatically, save result to JSONL file.
104
+ python calc_metrics.py --metrics=eqt50k_int,eqr50k \\
105
+ --network=~/training-runs/00000-stylegan3-r-mydataset/network-snapshot-000000.pkl
106
+
107
+ \b
108
+ # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
109
+ python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq-1024x1024.zip --mirror=1 \\
110
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl
111
+
112
+ \b
113
+ Recommended metrics:
114
+ fid50k_full Frechet inception distance against the full dataset.
115
+ kid50k_full Kernel inception distance against the full dataset.
116
+ pr50k3_full Precision and recall againt the full dataset.
117
+ ppl2_wend Perceptual path length in W, endpoints, full image.
118
+ eqt50k_int Equivariance w.r.t. integer translation (EQ-T).
119
+ eqt50k_frac Equivariance w.r.t. fractional translation (EQ-T_frac).
120
+ eqr50k Equivariance w.r.t. rotation (EQ-R).
121
+
122
+ \b
123
+ Legacy metrics:
124
+ fid50k Frechet inception distance against 50k real images.
125
+ kid50k Kernel inception distance against 50k real images.
126
+ pr50k3 Precision and recall against 50k real images.
127
+ is50k Inception score for CIFAR-10.
128
+ """
129
+ dnnlib.util.Logger(should_flush=True)
130
+
131
+ # Validate arguments.
132
+ args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
133
+ if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
134
+ ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
135
+ if not args.num_gpus >= 1:
136
+ ctx.fail('--gpus must be at least 1')
137
+
138
+ # Load network.
139
+ if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
140
+ ctx.fail('--network must point to a file or URL')
141
+ if args.verbose:
142
+ print(f'Loading network from "{network_pkl}"...')
143
+ with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
144
+ network_dict = legacy.load_network_pkl(f)
145
+ args.G = network_dict['G_ema'] # subclass of torch.nn.Module
146
+
147
+ # Initialize dataset options.
148
+ if data is not None:
149
+ args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
150
+ elif network_dict['training_set_kwargs'] is not None:
151
+ args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
152
+ else:
153
+ ctx.fail('Could not look up dataset options; please specify --data')
154
+
155
+ # Finalize dataset options.
156
+ args.dataset_kwargs.resolution = args.G.img_resolution
157
+ args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
158
+ if mirror is not None:
159
+ args.dataset_kwargs.xflip = mirror
160
+
161
+ # Print dataset options.
162
+ if args.verbose:
163
+ print('Dataset options:')
164
+ print(json.dumps(args.dataset_kwargs, indent=2))
165
+
166
+ # Locate run dir.
167
+ args.run_dir = None
168
+ if os.path.isfile(network_pkl):
169
+ pkl_dir = os.path.dirname(network_pkl)
170
+ if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
171
+ args.run_dir = pkl_dir
172
+
173
+ # Launch processes.
174
+ if args.verbose:
175
+ print('Launching processes...')
176
+ torch.multiprocessing.set_start_method('spawn')
177
+ with tempfile.TemporaryDirectory() as temp_dir:
178
+ if args.num_gpus == 1:
179
+ subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
180
+ else:
181
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
182
+
183
+ #----------------------------------------------------------------------------
184
+
185
+ if __name__ == "__main__":
186
+ calc_metrics() # pylint: disable=no-value-for-parameter
187
+
188
+ #----------------------------------------------------------------------------
dataset_tool.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Tool for creating ZIP/PNG based datasets."""
10
+
11
+ import functools
12
+ import gzip
13
+ import io
14
+ import json
15
+ import os
16
+ import pickle
17
+ import re
18
+ import sys
19
+ import tarfile
20
+ import zipfile
21
+ from pathlib import Path
22
+ from typing import Callable, Optional, Tuple, Union
23
+
24
+ import click
25
+ import numpy as np
26
+ import PIL.Image
27
+ from tqdm import tqdm
28
+
29
+ #----------------------------------------------------------------------------
30
+
31
+ def error(msg):
32
+ print('Error: ' + msg)
33
+ sys.exit(1)
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ def parse_tuple(s: str) -> Tuple[int, int]:
38
+ '''Parse a 'M,N' or 'MxN' integer tuple.
39
+
40
+ Example:
41
+ '4x2' returns (4,2)
42
+ '0,1' returns (0,1)
43
+ '''
44
+ if m := re.match(r'^(\d+)[x,](\d+)$', s):
45
+ return (int(m.group(1)), int(m.group(2)))
46
+ raise ValueError(f'cannot parse tuple {s}')
47
+
48
+ #----------------------------------------------------------------------------
49
+
50
+ def maybe_min(a: int, b: Optional[int]) -> int:
51
+ if b is not None:
52
+ return min(a, b)
53
+ return a
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
+ def file_ext(name: Union[str, Path]) -> str:
58
+ return str(name).split('.')[-1]
59
+
60
+ #----------------------------------------------------------------------------
61
+
62
+ def is_image_ext(fname: Union[str, Path]) -> bool:
63
+ ext = file_ext(fname).lower()
64
+ return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
65
+
66
+ #----------------------------------------------------------------------------
67
+
68
+ def open_image_folder(source_dir, *, max_images: Optional[int]):
69
+ input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
70
+
71
+ # Load labels.
72
+ labels = {}
73
+ meta_fname = os.path.join(source_dir, 'dataset.json')
74
+ if os.path.isfile(meta_fname):
75
+ with open(meta_fname, 'r') as file:
76
+ labels = json.load(file)['labels']
77
+ if labels is not None:
78
+ labels = { x[0]: x[1] for x in labels }
79
+ else:
80
+ labels = {}
81
+
82
+ max_idx = maybe_min(len(input_images), max_images)
83
+
84
+ def iterate_images():
85
+ for idx, fname in enumerate(input_images):
86
+ arch_fname = os.path.relpath(fname, source_dir)
87
+ arch_fname = arch_fname.replace('\\', '/')
88
+ img = np.array(PIL.Image.open(fname))
89
+ yield dict(img=img, label=labels.get(arch_fname))
90
+ if idx >= max_idx-1:
91
+ break
92
+ return max_idx, iterate_images()
93
+
94
+ #----------------------------------------------------------------------------
95
+
96
+ def open_image_zip(source, *, max_images: Optional[int]):
97
+ with zipfile.ZipFile(source, mode='r') as z:
98
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
99
+
100
+ # Load labels.
101
+ labels = {}
102
+ if 'dataset.json' in z.namelist():
103
+ with z.open('dataset.json', 'r') as file:
104
+ labels = json.load(file)['labels']
105
+ if labels is not None:
106
+ labels = { x[0]: x[1] for x in labels }
107
+ else:
108
+ labels = {}
109
+
110
+ max_idx = maybe_min(len(input_images), max_images)
111
+
112
+ def iterate_images():
113
+ with zipfile.ZipFile(source, mode='r') as z:
114
+ for idx, fname in enumerate(input_images):
115
+ with z.open(fname, 'r') as file:
116
+ img = PIL.Image.open(file) # type: ignore
117
+ img = np.array(img)
118
+ yield dict(img=img, label=labels.get(fname))
119
+ if idx >= max_idx-1:
120
+ break
121
+ return max_idx, iterate_images()
122
+
123
+ #----------------------------------------------------------------------------
124
+
125
+ def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
126
+ import cv2 # pip install opencv-python # pylint: disable=import-error
127
+ import lmdb # pip install lmdb # pylint: disable=import-error
128
+
129
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
130
+ max_idx = maybe_min(txn.stat()['entries'], max_images)
131
+
132
+ def iterate_images():
133
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
134
+ for idx, (_key, value) in enumerate(txn.cursor()):
135
+ try:
136
+ try:
137
+ img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
138
+ if img is None:
139
+ raise IOError('cv2.imdecode failed')
140
+ img = img[:, :, ::-1] # BGR => RGB
141
+ except IOError:
142
+ img = np.array(PIL.Image.open(io.BytesIO(value)))
143
+ yield dict(img=img, label=None)
144
+ if idx >= max_idx-1:
145
+ break
146
+ except:
147
+ print(sys.exc_info()[1])
148
+
149
+ return max_idx, iterate_images()
150
+
151
+ #----------------------------------------------------------------------------
152
+
153
+ def open_cifar10(tarball: str, *, max_images: Optional[int]):
154
+ images = []
155
+ labels = []
156
+
157
+ with tarfile.open(tarball, 'r:gz') as tar:
158
+ for batch in range(1, 6):
159
+ member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
160
+ with tar.extractfile(member) as file:
161
+ data = pickle.load(file, encoding='latin1')
162
+ images.append(data['data'].reshape(-1, 3, 32, 32))
163
+ labels.append(data['labels'])
164
+
165
+ images = np.concatenate(images)
166
+ labels = np.concatenate(labels)
167
+ images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
168
+ assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
169
+ assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
170
+ assert np.min(images) == 0 and np.max(images) == 255
171
+ assert np.min(labels) == 0 and np.max(labels) == 9
172
+
173
+ max_idx = maybe_min(len(images), max_images)
174
+
175
+ def iterate_images():
176
+ for idx, img in enumerate(images):
177
+ yield dict(img=img, label=int(labels[idx]))
178
+ if idx >= max_idx-1:
179
+ break
180
+
181
+ return max_idx, iterate_images()
182
+
183
+ #----------------------------------------------------------------------------
184
+
185
+ def open_mnist(images_gz: str, *, max_images: Optional[int]):
186
+ labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
187
+ assert labels_gz != images_gz
188
+ images = []
189
+ labels = []
190
+
191
+ with gzip.open(images_gz, 'rb') as f:
192
+ images = np.frombuffer(f.read(), np.uint8, offset=16)
193
+ with gzip.open(labels_gz, 'rb') as f:
194
+ labels = np.frombuffer(f.read(), np.uint8, offset=8)
195
+
196
+ images = images.reshape(-1, 28, 28)
197
+ images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
198
+ assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
199
+ assert labels.shape == (60000,) and labels.dtype == np.uint8
200
+ assert np.min(images) == 0 and np.max(images) == 255
201
+ assert np.min(labels) == 0 and np.max(labels) == 9
202
+
203
+ max_idx = maybe_min(len(images), max_images)
204
+
205
+ def iterate_images():
206
+ for idx, img in enumerate(images):
207
+ yield dict(img=img, label=int(labels[idx]))
208
+ if idx >= max_idx-1:
209
+ break
210
+
211
+ return max_idx, iterate_images()
212
+
213
+ #----------------------------------------------------------------------------
214
+
215
+ def make_transform(
216
+ transform: Optional[str],
217
+ output_width: Optional[int],
218
+ output_height: Optional[int]
219
+ ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
220
+ def scale(width, height, img):
221
+ w = img.shape[1]
222
+ h = img.shape[0]
223
+ if width == w and height == h:
224
+ return img
225
+ img = PIL.Image.fromarray(img)
226
+ ww = width if width is not None else w
227
+ hh = height if height is not None else h
228
+ img = img.resize((ww, hh), PIL.Image.LANCZOS)
229
+ return np.array(img)
230
+
231
+ def center_crop(width, height, img):
232
+ crop = np.min(img.shape[:2])
233
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
234
+ img = PIL.Image.fromarray(img, 'RGB')
235
+ img = img.resize((width, height), PIL.Image.LANCZOS)
236
+ return np.array(img)
237
+
238
+ def center_crop_wide(width, height, img):
239
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
240
+ if img.shape[1] < width or ch < height:
241
+ return None
242
+
243
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
244
+ img = PIL.Image.fromarray(img, 'RGB')
245
+ img = img.resize((width, height), PIL.Image.LANCZOS)
246
+ img = np.array(img)
247
+
248
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
249
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
250
+ return canvas
251
+
252
+ if transform is None:
253
+ return functools.partial(scale, output_width, output_height)
254
+ if transform == 'center-crop':
255
+ if (output_width is None) or (output_height is None):
256
+ error ('must specify --resolution=WxH when using ' + transform + 'transform')
257
+ return functools.partial(center_crop, output_width, output_height)
258
+ if transform == 'center-crop-wide':
259
+ if (output_width is None) or (output_height is None):
260
+ error ('must specify --resolution=WxH when using ' + transform + ' transform')
261
+ return functools.partial(center_crop_wide, output_width, output_height)
262
+ assert False, 'unknown transform'
263
+
264
+ #----------------------------------------------------------------------------
265
+
266
+ def open_dataset(source, *, max_images: Optional[int]):
267
+ if os.path.isdir(source):
268
+ if source.rstrip('/').endswith('_lmdb'):
269
+ return open_lmdb(source, max_images=max_images)
270
+ else:
271
+ return open_image_folder(source, max_images=max_images)
272
+ elif os.path.isfile(source):
273
+ if os.path.basename(source) == 'cifar-10-python.tar.gz':
274
+ return open_cifar10(source, max_images=max_images)
275
+ elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
276
+ return open_mnist(source, max_images=max_images)
277
+ elif file_ext(source) == 'zip':
278
+ return open_image_zip(source, max_images=max_images)
279
+ else:
280
+ assert False, 'unknown archive type'
281
+ else:
282
+ error(f'Missing input file or directory: {source}')
283
+
284
+ #----------------------------------------------------------------------------
285
+
286
+ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
287
+ dest_ext = file_ext(dest)
288
+
289
+ if dest_ext == 'zip':
290
+ if os.path.dirname(dest) != '':
291
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
292
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
293
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
294
+ zf.writestr(fname, data)
295
+ return '', zip_write_bytes, zf.close
296
+ else:
297
+ # If the output folder already exists, check that is is
298
+ # empty.
299
+ #
300
+ # Note: creating the output directory is not strictly
301
+ # necessary as folder_write_bytes() also mkdirs, but it's better
302
+ # to give an error message earlier in case the dest folder
303
+ # somehow cannot be created.
304
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
305
+ error('--dest folder must be empty')
306
+ os.makedirs(dest, exist_ok=True)
307
+
308
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
309
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
310
+ with open(fname, 'wb') as fout:
311
+ if isinstance(data, str):
312
+ data = data.encode('utf8')
313
+ fout.write(data)
314
+ return dest, folder_write_bytes, lambda: None
315
+
316
+ #----------------------------------------------------------------------------
317
+
318
+ @click.command()
319
+ @click.pass_context
320
+ @click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
321
+ @click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
322
+ @click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
323
+ @click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
324
+ @click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
325
+ def convert_dataset(
326
+ ctx: click.Context,
327
+ source: str,
328
+ dest: str,
329
+ max_images: Optional[int],
330
+ transform: Optional[str],
331
+ resolution: Optional[Tuple[int, int]]
332
+ ):
333
+ """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
334
+
335
+ The input dataset format is guessed from the --source argument:
336
+
337
+ \b
338
+ --source *_lmdb/ Load LSUN dataset
339
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
340
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
341
+ --source path/ Recursively load all images from path/
342
+ --source dataset.zip Recursively load all images from dataset.zip
343
+
344
+ Specifying the output format and path:
345
+
346
+ \b
347
+ --dest /path/to/dir Save output files under /path/to/dir
348
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
349
+
350
+ The output dataset format can be either an image folder or an uncompressed zip archive.
351
+ Zip archives makes it easier to move datasets around file servers and clusters, and may
352
+ offer better training performance on network file systems.
353
+
354
+ Images within the dataset archive will be stored as uncompressed PNG.
355
+ Uncompresed PNGs can be efficiently decoded in the training loop.
356
+
357
+ Class labels are stored in a file called 'dataset.json' that is stored at the
358
+ dataset root folder. This file has the following structure:
359
+
360
+ \b
361
+ {
362
+ "labels": [
363
+ ["00000/img00000000.png",6],
364
+ ["00000/img00000001.png",9],
365
+ ... repeated for every image in the datase
366
+ ["00049/img00049999.png",1]
367
+ ]
368
+ }
369
+
370
+ If the 'dataset.json' file cannot be found, the dataset is interpreted as
371
+ not containing class labels.
372
+
373
+ Image scale/crop and resolution requirements:
374
+
375
+ Output images must be square-shaped and they must all have the same power-of-two
376
+ dimensions.
377
+
378
+ To scale arbitrary input image size to a specific width and height, use the
379
+ --resolution option. Output resolution will be either the original
380
+ input resolution (if resolution was not specified) or the one specified with
381
+ --resolution option.
382
+
383
+ Use the --transform=center-crop or --transform=center-crop-wide options to apply a
384
+ center crop transform on the input image. These options should be used with the
385
+ --resolution option. For example:
386
+
387
+ \b
388
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
389
+ --transform=center-crop-wide --resolution=512x384
390
+ """
391
+
392
+ PIL.Image.init() # type: ignore
393
+
394
+ if dest == '':
395
+ ctx.fail('--dest output filename or directory must not be an empty string')
396
+
397
+ num_files, input_iter = open_dataset(source, max_images=max_images)
398
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
399
+
400
+ if resolution is None: resolution = (None, None)
401
+ transform_image = make_transform(transform, *resolution)
402
+
403
+ dataset_attrs = None
404
+
405
+ labels = []
406
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
407
+ idx_str = f'{idx:08d}'
408
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
409
+
410
+ # Apply crop and resize.
411
+ img = transform_image(image['img'])
412
+
413
+ # Transform may drop images.
414
+ if img is None:
415
+ continue
416
+
417
+ # Error check to require uniform image attributes across
418
+ # the whole dataset.
419
+ channels = img.shape[2] if img.ndim == 3 else 1
420
+ cur_image_attrs = {
421
+ 'width': img.shape[1],
422
+ 'height': img.shape[0],
423
+ 'channels': channels
424
+ }
425
+ if dataset_attrs is None:
426
+ dataset_attrs = cur_image_attrs
427
+ width = dataset_attrs['width']
428
+ height = dataset_attrs['height']
429
+ if width != height:
430
+ error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
431
+ if dataset_attrs['channels'] not in [1, 3]:
432
+ error('Input images must be stored as RGB or grayscale')
433
+ if width != 2 ** int(np.floor(np.log2(width))):
434
+ error('Image width/height after scale and crop are required to be power-of-two')
435
+ elif dataset_attrs != cur_image_attrs:
436
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()] # pylint: disable=unsubscriptable-object
437
+ error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
438
+
439
+ # Save the image as an uncompressed PNG.
440
+ img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels])
441
+ image_bits = io.BytesIO()
442
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
443
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
444
+ labels.append([archive_fname, image['label']] if image['label'] is not None else None)
445
+
446
+ metadata = {
447
+ 'labels': labels if all(x is not None for x in labels) else None
448
+ }
449
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
450
+ close_dest()
451
+
452
+ #----------------------------------------------------------------------------
453
+
454
+ if __name__ == "__main__":
455
+ convert_dataset() # pylint: disable=no-value-for-parameter
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def format_time_brief(seconds: Union[int, float]) -> str:
154
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
155
+ s = int(np.rint(seconds))
156
+
157
+ if s < 60:
158
+ return "{0}s".format(s)
159
+ elif s < 60 * 60:
160
+ return "{0}m {1:02}s".format(s // 60, s % 60)
161
+ elif s < 24 * 60 * 60:
162
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
163
+ else:
164
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
165
+
166
+
167
+ def ask_yes_no(question: str) -> bool:
168
+ """Ask the user the question until the user inputs a valid answer."""
169
+ while True:
170
+ try:
171
+ print("{0} [y/n]".format(question))
172
+ return strtobool(input().lower())
173
+ except ValueError:
174
+ pass
175
+
176
+
177
+ def tuple_product(t: Tuple) -> Any:
178
+ """Calculate the product of the tuple elements."""
179
+ result = 1
180
+
181
+ for v in t:
182
+ result *= v
183
+
184
+ return result
185
+
186
+
187
+ _str_to_ctype = {
188
+ "uint8": ctypes.c_ubyte,
189
+ "uint16": ctypes.c_uint16,
190
+ "uint32": ctypes.c_uint32,
191
+ "uint64": ctypes.c_uint64,
192
+ "int8": ctypes.c_byte,
193
+ "int16": ctypes.c_int16,
194
+ "int32": ctypes.c_int32,
195
+ "int64": ctypes.c_int64,
196
+ "float32": ctypes.c_float,
197
+ "float64": ctypes.c_double
198
+ }
199
+
200
+
201
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
202
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
203
+ type_str = None
204
+
205
+ if isinstance(type_obj, str):
206
+ type_str = type_obj
207
+ elif hasattr(type_obj, "__name__"):
208
+ type_str = type_obj.__name__
209
+ elif hasattr(type_obj, "name"):
210
+ type_str = type_obj.name
211
+ else:
212
+ raise RuntimeError("Cannot infer type name from input")
213
+
214
+ assert type_str in _str_to_ctype.keys()
215
+
216
+ my_dtype = np.dtype(type_str)
217
+ my_ctype = _str_to_ctype[type_str]
218
+
219
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
220
+
221
+ return my_dtype, my_ctype
222
+
223
+
224
+ def is_pickleable(obj: Any) -> bool:
225
+ try:
226
+ with io.BytesIO() as stream:
227
+ pickle.dump(obj, stream)
228
+ return True
229
+ except:
230
+ return False
231
+
232
+
233
+ # Functionality to import modules/objects by name, and call functions by name
234
+ # ------------------------------------------------------------------------------------------
235
+
236
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
237
+ """Searches for the underlying module behind the name to some python object.
238
+ Returns the module and the object name (original name with module part removed)."""
239
+
240
+ # allow convenience shorthands, substitute them by full names
241
+ obj_name = re.sub("^np.", "numpy.", obj_name)
242
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
243
+
244
+ # list alternatives for (module_name, local_obj_name)
245
+ parts = obj_name.split(".")
246
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
247
+
248
+ # try each alternative in turn
249
+ for module_name, local_obj_name in name_pairs:
250
+ try:
251
+ module = importlib.import_module(module_name) # may raise ImportError
252
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
253
+ return module, local_obj_name
254
+ except:
255
+ pass
256
+
257
+ # maybe some of the modules themselves contain errors?
258
+ for module_name, _local_obj_name in name_pairs:
259
+ try:
260
+ importlib.import_module(module_name) # may raise ImportError
261
+ except ImportError:
262
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
263
+ raise
264
+
265
+ # maybe the requested attribute is missing?
266
+ for module_name, local_obj_name in name_pairs:
267
+ try:
268
+ module = importlib.import_module(module_name) # may raise ImportError
269
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
270
+ except ImportError:
271
+ pass
272
+
273
+ # we are out of luck, but we have no idea why
274
+ raise ImportError(obj_name)
275
+
276
+
277
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
278
+ """Traverses the object name and returns the last (rightmost) python object."""
279
+ if obj_name == '':
280
+ return module
281
+ obj = module
282
+ for part in obj_name.split("."):
283
+ obj = getattr(obj, part)
284
+ return obj
285
+
286
+
287
+ def get_obj_by_name(name: str) -> Any:
288
+ """Finds the python object with the given name."""
289
+ module, obj_name = get_module_from_obj_name(name)
290
+ return get_obj_from_module(module, obj_name)
291
+
292
+
293
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
294
+ """Finds the python object with the given name and calls it as a function."""
295
+ assert func_name is not None
296
+ func_obj = get_obj_by_name(func_name)
297
+ assert callable(func_obj)
298
+ return func_obj(*args, **kwargs)
299
+
300
+
301
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
302
+ """Finds the python class with the given name and constructs it with the given arguments."""
303
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
304
+
305
+
306
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
307
+ """Get the directory path of the module containing the given object name."""
308
+ module, _ = get_module_from_obj_name(obj_name)
309
+ return os.path.dirname(inspect.getfile(module))
310
+
311
+
312
+ def is_top_level_function(obj: Any) -> bool:
313
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
314
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
315
+
316
+
317
+ def get_top_level_function_name(obj: Any) -> str:
318
+ """Return the fully-qualified name of a top-level function."""
319
+ assert is_top_level_function(obj)
320
+ module = obj.__module__
321
+ if module == '__main__':
322
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
323
+ return module + "." + obj.__name__
324
+
325
+
326
+ # File system helpers
327
+ # ------------------------------------------------------------------------------------------
328
+
329
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
330
+ """List all files recursively in a given directory while ignoring given file and directory names.
331
+ Returns list of tuples containing both absolute and relative paths."""
332
+ assert os.path.isdir(dir_path)
333
+ base_name = os.path.basename(os.path.normpath(dir_path))
334
+
335
+ if ignores is None:
336
+ ignores = []
337
+
338
+ result = []
339
+
340
+ for root, dirs, files in os.walk(dir_path, topdown=True):
341
+ for ignore_ in ignores:
342
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
343
+
344
+ # dirs need to be edited in-place
345
+ for d in dirs_to_remove:
346
+ dirs.remove(d)
347
+
348
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
349
+
350
+ absolute_paths = [os.path.join(root, f) for f in files]
351
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
352
+
353
+ if add_base_to_relative:
354
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
355
+
356
+ assert len(absolute_paths) == len(relative_paths)
357
+ result += zip(absolute_paths, relative_paths)
358
+
359
+ return result
360
+
361
+
362
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
363
+ """Takes in a list of tuples of (src, dst) paths and copies files.
364
+ Will create all necessary directories."""
365
+ for file in files:
366
+ target_dir_name = os.path.dirname(file[1])
367
+
368
+ # will create all intermediate-level directories
369
+ if not os.path.exists(target_dir_name):
370
+ os.makedirs(target_dir_name)
371
+
372
+ shutil.copyfile(file[0], file[1])
373
+
374
+
375
+ # URL helpers
376
+ # ------------------------------------------------------------------------------------------
377
+
378
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
379
+ """Determine whether the given object is a valid URL string."""
380
+ if not isinstance(obj, str) or not "://" in obj:
381
+ return False
382
+ if allow_file_urls and obj.startswith('file://'):
383
+ return True
384
+ try:
385
+ res = requests.compat.urlparse(obj)
386
+ if not res.scheme or not res.netloc or not "." in res.netloc:
387
+ return False
388
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
389
+ if not res.scheme or not res.netloc or not "." in res.netloc:
390
+ return False
391
+ except:
392
+ return False
393
+ return True
394
+
395
+
396
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
397
+ """Download the given URL and return a binary-mode file object to access the data."""
398
+ assert num_attempts >= 1
399
+ assert not (return_filename and (not cache))
400
+
401
+ # Doesn't look like an URL scheme so interpret it as a local filename.
402
+ if not re.match('^[a-z]+://', url):
403
+ return url if return_filename else open(url, "rb")
404
+
405
+ # Handle file URLs. This code handles unusual file:// patterns that
406
+ # arise on Windows:
407
+ #
408
+ # file:///c:/foo.txt
409
+ #
410
+ # which would translate to a local '/c:/foo.txt' filename that's
411
+ # invalid. Drop the forward slash for such pathnames.
412
+ #
413
+ # If you touch this code path, you should test it on both Linux and
414
+ # Windows.
415
+ #
416
+ # Some internet resources suggest using urllib.request.url2pathname() but
417
+ # but that converts forward slashes to backslashes and this causes
418
+ # its own set of problems.
419
+ if url.startswith('file://'):
420
+ filename = urllib.parse.urlparse(url).path
421
+ if re.match(r'^/[a-zA-Z]:', filename):
422
+ filename = filename[1:]
423
+ return filename if return_filename else open(filename, "rb")
424
+
425
+ assert is_url(url)
426
+
427
+ # Lookup from cache.
428
+ if cache_dir is None:
429
+ cache_dir = make_cache_dir_path('downloads')
430
+
431
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
432
+ if cache:
433
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
434
+ if len(cache_files) == 1:
435
+ filename = cache_files[0]
436
+ return filename if return_filename else open(filename, "rb")
437
+
438
+ # Download.
439
+ url_name = None
440
+ url_data = None
441
+ with requests.Session() as session:
442
+ if verbose:
443
+ print("Downloading %s ..." % url, end="", flush=True)
444
+ for attempts_left in reversed(range(num_attempts)):
445
+ try:
446
+ with session.get(url) as res:
447
+ res.raise_for_status()
448
+ if len(res.content) == 0:
449
+ raise IOError("No data received")
450
+
451
+ if len(res.content) < 8192:
452
+ content_str = res.content.decode("utf-8")
453
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
454
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
455
+ if len(links) == 1:
456
+ url = requests.compat.urljoin(url, links[0])
457
+ raise IOError("Google Drive virus checker nag")
458
+ if "Google Drive - Quota exceeded" in content_str:
459
+ raise IOError("Google Drive download quota exceeded -- please try again later")
460
+
461
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
462
+ url_name = match[1] if match else url
463
+ url_data = res.content
464
+ if verbose:
465
+ print(" done")
466
+ break
467
+ except KeyboardInterrupt:
468
+ raise
469
+ except:
470
+ if not attempts_left:
471
+ if verbose:
472
+ print(" failed")
473
+ raise
474
+ if verbose:
475
+ print(".", end="", flush=True)
476
+
477
+ # Save to cache.
478
+ if cache:
479
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
480
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482
+ os.makedirs(cache_dir, exist_ok=True)
483
+ with open(temp_file, "wb") as f:
484
+ f.write(url_data)
485
+ os.replace(temp_file, cache_file) # atomic
486
+ if return_filename:
487
+ return cache_file
488
+
489
+ # Return data as file object.
490
+ assert not return_filename
491
+ return io.BytesIO(url_data)
environment.yml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: stylegan3
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python >= 3.8
7
+ - pip
8
+ - numpy>=1.20
9
+ - click>=8.0
10
+ - pillow=8.3.1
11
+ - scipy=1.7.1
12
+ - pytorch=1.9.1
13
+ - cudatoolkit=11.1
14
+ - requests=2.26.0
15
+ - tqdm=4.62.2
16
+ - ninja=1.10.2
17
+ - matplotlib=3.4.2
18
+ - imageio=2.9.0
19
+ - pip:
20
+ - imgui==1.3.0
21
+ - glfw==2.2.0
22
+ - pyopengl==3.1.5
23
+ - imageio-ffmpeg==0.4.3
24
+ - pyspng
gen_images.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Generate images using pretrained network pickle."""
10
+
11
+ import os
12
+ import re
13
+ from typing import List, Optional, Tuple, Union
14
+
15
+ import click
16
+ import dnnlib
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+
21
+ import legacy
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ def parse_range(s: Union[str, List]) -> List[int]:
26
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
27
+
28
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
29
+ '''
30
+ if isinstance(s, list): return s
31
+ ranges = []
32
+ range_re = re.compile(r'^(\d+)-(\d+)$')
33
+ for p in s.split(','):
34
+ if m := range_re.match(p):
35
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
36
+ else:
37
+ ranges.append(int(p))
38
+ return ranges
39
+
40
+ #----------------------------------------------------------------------------
41
+
42
+ def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
43
+ '''Parse a floating point 2-vector of syntax 'a,b'.
44
+
45
+ Example:
46
+ '0,1' returns (0,1)
47
+ '''
48
+ if isinstance(s, tuple): return s
49
+ parts = s.split(',')
50
+ if len(parts) == 2:
51
+ return (float(parts[0]), float(parts[1]))
52
+ raise ValueError(f'cannot parse 2-vector {s}')
53
+
54
+ #----------------------------------------------------------------------------
55
+
56
+ def make_transform(translate: Tuple[float,float], angle: float):
57
+ m = np.eye(3)
58
+ s = np.sin(angle/360.0*np.pi*2)
59
+ c = np.cos(angle/360.0*np.pi*2)
60
+ m[0][0] = c
61
+ m[0][1] = s
62
+ m[0][2] = translate[0]
63
+ m[1][0] = -s
64
+ m[1][1] = c
65
+ m[1][2] = translate[1]
66
+ return m
67
+
68
+ #----------------------------------------------------------------------------
69
+
70
+ @click.command()
71
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
72
+ @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
73
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
74
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
75
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
76
+ @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
77
+ @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
78
+ @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
79
+ def generate_images(
80
+ network_pkl: str,
81
+ seeds: List[int],
82
+ truncation_psi: float,
83
+ noise_mode: str,
84
+ outdir: str,
85
+ translate: Tuple[float,float],
86
+ rotate: float,
87
+ class_idx: Optional[int]
88
+ ):
89
+ """Generate images using pretrained network pickle.
90
+
91
+ Examples:
92
+
93
+ \b
94
+ # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
95
+ python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
96
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
97
+
98
+ \b
99
+ # Generate uncurated images with truncation using the MetFaces-U dataset
100
+ python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
101
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
102
+ """
103
+
104
+ print('Loading networks from "%s"...' % network_pkl)
105
+ device = torch.device('cuda')
106
+ with dnnlib.util.open_url(network_pkl) as f:
107
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
108
+
109
+ os.makedirs(outdir, exist_ok=True)
110
+
111
+ # Labels.
112
+ label = torch.zeros([1, G.c_dim], device=device)
113
+ if G.c_dim != 0:
114
+ if class_idx is None:
115
+ raise click.ClickException('Must specify class label with --class when using a conditional network')
116
+ label[:, class_idx] = 1
117
+ else:
118
+ if class_idx is not None:
119
+ print ('warn: --class=lbl ignored when running on an unconditional network')
120
+
121
+ # Generate images.
122
+ for seed_idx, seed in enumerate(seeds):
123
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
124
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
125
+
126
+ # Construct an inverse rotation/translation matrix and pass to the generator. The
127
+ # generator expects this matrix as an inverse to avoid potentially failing numerical
128
+ # operations in the network.
129
+ if hasattr(G.synthesis, 'input'):
130
+ m = make_transform(translate, rotate)
131
+ m = np.linalg.inv(m)
132
+ G.synthesis.input.transform.copy_(torch.from_numpy(m))
133
+
134
+ img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
135
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
136
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
137
+
138
+
139
+ #----------------------------------------------------------------------------
140
+
141
+ if __name__ == "__main__":
142
+ generate_images() # pylint: disable=no-value-for-parameter
143
+
144
+ #----------------------------------------------------------------------------
gen_video.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Generate lerp videos using pretrained network pickle."""
10
+
11
+ import copy
12
+ import os
13
+ import re
14
+ from typing import List, Optional, Tuple, Union
15
+
16
+ import click
17
+ import dnnlib
18
+ import imageio
19
+ import numpy as np
20
+ import scipy.interpolate
21
+ import torch
22
+ from tqdm import tqdm
23
+
24
+ import legacy
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True):
29
+ batch_size, channels, img_h, img_w = img.shape
30
+ if grid_w is None:
31
+ grid_w = batch_size // grid_h
32
+ assert batch_size == grid_w * grid_h
33
+ if float_to_uint8:
34
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
35
+ img = img.reshape(grid_h, grid_w, channels, img_h, img_w)
36
+ img = img.permute(2, 0, 3, 1, 4)
37
+ img = img.reshape(channels, grid_h * img_h, grid_w * img_w)
38
+ if chw_to_hwc:
39
+ img = img.permute(1, 2, 0)
40
+ if to_numpy:
41
+ img = img.cpu().numpy()
42
+ return img
43
+
44
+ #----------------------------------------------------------------------------
45
+
46
+ def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs):
47
+ grid_w = grid_dims[0]
48
+ grid_h = grid_dims[1]
49
+
50
+ if num_keyframes is None:
51
+ if len(seeds) % (grid_w*grid_h) != 0:
52
+ raise ValueError('Number of input seeds must be divisible by grid W*H')
53
+ num_keyframes = len(seeds) // (grid_w*grid_h)
54
+
55
+ all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64)
56
+ for idx in range(num_keyframes*grid_h*grid_w):
57
+ all_seeds[idx] = seeds[idx % len(seeds)]
58
+
59
+ if shuffle_seed is not None:
60
+ rng = np.random.RandomState(seed=shuffle_seed)
61
+ rng.shuffle(all_seeds)
62
+
63
+ zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device)
64
+ ws = G.mapping(z=zs, c=None, truncation_psi=psi)
65
+ _ = G.synthesis(ws[:1]) # warm up
66
+ ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:])
67
+
68
+ # Interpolation.
69
+ grid = []
70
+ for yi in range(grid_h):
71
+ row = []
72
+ for xi in range(grid_w):
73
+ x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1))
74
+ y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1])
75
+ interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0)
76
+ row.append(interp)
77
+ grid.append(row)
78
+
79
+ # Render video.
80
+ video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs)
81
+ for frame_idx in tqdm(range(num_keyframes * w_frames)):
82
+ imgs = []
83
+ for yi in range(grid_h):
84
+ for xi in range(grid_w):
85
+ interp = grid[yi][xi]
86
+ w = torch.from_numpy(interp(frame_idx / w_frames)).to(device)
87
+ img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0]
88
+ imgs.append(img)
89
+ video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h))
90
+ video_out.close()
91
+
92
+ #----------------------------------------------------------------------------
93
+
94
+ def parse_range(s: Union[str, List[int]]) -> List[int]:
95
+ '''Parse a comma separated list of numbers or ranges and return a list of ints.
96
+
97
+ Example: '1,2,5-10' returns [1, 2, 5, 6, 7]
98
+ '''
99
+ if isinstance(s, list): return s
100
+ ranges = []
101
+ range_re = re.compile(r'^(\d+)-(\d+)$')
102
+ for p in s.split(','):
103
+ if m := range_re.match(p):
104
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
105
+ else:
106
+ ranges.append(int(p))
107
+ return ranges
108
+
109
+ #----------------------------------------------------------------------------
110
+
111
+ def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]:
112
+ '''Parse a 'M,N' or 'MxN' integer tuple.
113
+
114
+ Example:
115
+ '4x2' returns (4,2)
116
+ '0,1' returns (0,1)
117
+ '''
118
+ if isinstance(s, tuple): return s
119
+ if m := re.match(r'^(\d+)[x,](\d+)$', s):
120
+ return (int(m.group(1)), int(m.group(2)))
121
+ raise ValueError(f'cannot parse tuple {s}')
122
+
123
+ #----------------------------------------------------------------------------
124
+
125
+ @click.command()
126
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
127
+ @click.option('--seeds', type=parse_range, help='List of random seeds', required=True)
128
+ @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None)
129
+ @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1))
130
+ @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None)
131
+ @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120)
132
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
133
+ @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE')
134
+ def generate_images(
135
+ network_pkl: str,
136
+ seeds: List[int],
137
+ shuffle_seed: Optional[int],
138
+ truncation_psi: float,
139
+ grid: Tuple[int,int],
140
+ num_keyframes: Optional[int],
141
+ w_frames: int,
142
+ output: str
143
+ ):
144
+ """Render a latent vector interpolation video.
145
+
146
+ Examples:
147
+
148
+ \b
149
+ # Render a 4x2 grid of interpolations for seeds 0 through 31.
150
+ python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\
151
+ --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl
152
+
153
+ Animation length and seed keyframes:
154
+
155
+ The animation length is either determined based on the --seeds value or explicitly
156
+ specified using the --num-keyframes option.
157
+
158
+ When num keyframes is specified with --num-keyframes, the output video length
159
+ will be 'num_keyframes*w_frames' frames.
160
+
161
+ If --num-keyframes is not specified, the number of seeds given with
162
+ --seeds must be divisible by grid size W*H (--grid). In this case the
163
+ output video length will be '# seeds/(w*h)*w_frames' frames.
164
+ """
165
+
166
+ print('Loading networks from "%s"...' % network_pkl)
167
+ device = torch.device('cuda')
168
+ with dnnlib.util.open_url(network_pkl) as f:
169
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
170
+
171
+ gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi)
172
+
173
+ #----------------------------------------------------------------------------
174
+
175
+ if __name__ == "__main__":
176
+ generate_images() # pylint: disable=no-value-for-parameter
177
+
178
+ #----------------------------------------------------------------------------
gui_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
gui_utils/gl_utils.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import functools
11
+ import contextlib
12
+ import numpy as np
13
+ import OpenGL.GL as gl
14
+ import OpenGL.GL.ARB.texture_float
15
+ import dnnlib
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ def init_egl():
20
+ assert os.environ['PYOPENGL_PLATFORM'] == 'egl' # Must be set before importing OpenGL.
21
+ import OpenGL.EGL as egl
22
+ import ctypes
23
+
24
+ # Initialize EGL.
25
+ display = egl.eglGetDisplay(egl.EGL_DEFAULT_DISPLAY)
26
+ assert display != egl.EGL_NO_DISPLAY
27
+ major = ctypes.c_int32()
28
+ minor = ctypes.c_int32()
29
+ ok = egl.eglInitialize(display, major, minor)
30
+ assert ok
31
+ assert major.value * 10 + minor.value >= 14
32
+
33
+ # Choose config.
34
+ config_attribs = [
35
+ egl.EGL_RENDERABLE_TYPE, egl.EGL_OPENGL_BIT,
36
+ egl.EGL_SURFACE_TYPE, egl.EGL_PBUFFER_BIT,
37
+ egl.EGL_NONE
38
+ ]
39
+ configs = (ctypes.c_int32 * 1)()
40
+ num_configs = ctypes.c_int32()
41
+ ok = egl.eglChooseConfig(display, config_attribs, configs, 1, num_configs)
42
+ assert ok
43
+ assert num_configs.value == 1
44
+ config = configs[0]
45
+
46
+ # Create dummy pbuffer surface.
47
+ surface_attribs = [
48
+ egl.EGL_WIDTH, 1,
49
+ egl.EGL_HEIGHT, 1,
50
+ egl.EGL_NONE
51
+ ]
52
+ surface = egl.eglCreatePbufferSurface(display, config, surface_attribs)
53
+ assert surface != egl.EGL_NO_SURFACE
54
+
55
+ # Setup GL context.
56
+ ok = egl.eglBindAPI(egl.EGL_OPENGL_API)
57
+ assert ok
58
+ context = egl.eglCreateContext(display, config, egl.EGL_NO_CONTEXT, None)
59
+ assert context != egl.EGL_NO_CONTEXT
60
+ ok = egl.eglMakeCurrent(display, surface, surface, context)
61
+ assert ok
62
+
63
+ #----------------------------------------------------------------------------
64
+
65
+ _texture_formats = {
66
+ ('uint8', 1): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE, internalformat=gl.GL_LUMINANCE8),
67
+ ('uint8', 2): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_LUMINANCE_ALPHA, internalformat=gl.GL_LUMINANCE8_ALPHA8),
68
+ ('uint8', 3): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGB, internalformat=gl.GL_RGB8),
69
+ ('uint8', 4): dnnlib.EasyDict(type=gl.GL_UNSIGNED_BYTE, format=gl.GL_RGBA, internalformat=gl.GL_RGBA8),
70
+ ('float32', 1): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE32F_ARB),
71
+ ('float32', 2): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_LUMINANCE_ALPHA, internalformat=OpenGL.GL.ARB.texture_float.GL_LUMINANCE_ALPHA32F_ARB),
72
+ ('float32', 3): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGB, internalformat=gl.GL_RGB32F),
73
+ ('float32', 4): dnnlib.EasyDict(type=gl.GL_FLOAT, format=gl.GL_RGBA, internalformat=gl.GL_RGBA32F),
74
+ }
75
+
76
+ def get_texture_format(dtype, channels):
77
+ return _texture_formats[(np.dtype(dtype).name, int(channels))]
78
+
79
+ #----------------------------------------------------------------------------
80
+
81
+ def prepare_texture_data(image):
82
+ image = np.asarray(image)
83
+ if image.ndim == 2:
84
+ image = image[:, :, np.newaxis]
85
+ if image.dtype.name == 'float64':
86
+ image = image.astype('float32')
87
+ return image
88
+
89
+ #----------------------------------------------------------------------------
90
+
91
+ def draw_pixels(image, *, pos=0, zoom=1, align=0, rint=True):
92
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
93
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
94
+ align = np.broadcast_to(np.asarray(align, dtype='float32'), [2])
95
+ image = prepare_texture_data(image)
96
+ height, width, channels = image.shape
97
+ size = zoom * [width, height]
98
+ pos = pos - size * align
99
+ if rint:
100
+ pos = np.rint(pos)
101
+ fmt = get_texture_format(image.dtype, channels)
102
+
103
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_PIXEL_MODE_BIT)
104
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
105
+ gl.glRasterPos2f(pos[0], pos[1])
106
+ gl.glPixelZoom(zoom[0], -zoom[1])
107
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
108
+ gl.glDrawPixels(width, height, fmt.format, fmt.type, image)
109
+ gl.glPopClientAttrib()
110
+ gl.glPopAttrib()
111
+
112
+ #----------------------------------------------------------------------------
113
+
114
+ def read_pixels(width, height, *, pos=0, dtype='uint8', channels=3):
115
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
116
+ dtype = np.dtype(dtype)
117
+ fmt = get_texture_format(dtype, channels)
118
+ image = np.empty([height, width, channels], dtype=dtype)
119
+
120
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
121
+ gl.glPixelStorei(gl.GL_PACK_ALIGNMENT, 1)
122
+ gl.glReadPixels(int(np.round(pos[0])), int(np.round(pos[1])), width, height, fmt.format, fmt.type, image)
123
+ gl.glPopClientAttrib()
124
+ return np.flipud(image)
125
+
126
+ #----------------------------------------------------------------------------
127
+
128
+ class Texture:
129
+ def __init__(self, *, image=None, width=None, height=None, channels=None, dtype=None, bilinear=True, mipmap=True):
130
+ self.gl_id = None
131
+ self.bilinear = bilinear
132
+ self.mipmap = mipmap
133
+
134
+ # Determine size and dtype.
135
+ if image is not None:
136
+ image = prepare_texture_data(image)
137
+ self.height, self.width, self.channels = image.shape
138
+ self.dtype = image.dtype
139
+ else:
140
+ assert width is not None and height is not None
141
+ self.width = width
142
+ self.height = height
143
+ self.channels = channels if channels is not None else 3
144
+ self.dtype = np.dtype(dtype) if dtype is not None else np.uint8
145
+
146
+ # Validate size and dtype.
147
+ assert isinstance(self.width, int) and self.width >= 0
148
+ assert isinstance(self.height, int) and self.height >= 0
149
+ assert isinstance(self.channels, int) and self.channels >= 1
150
+ assert self.is_compatible(width=width, height=height, channels=channels, dtype=dtype)
151
+
152
+ # Create texture object.
153
+ self.gl_id = gl.glGenTextures(1)
154
+ with self.bind():
155
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE)
156
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE)
157
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR if self.bilinear else gl.GL_NEAREST)
158
+ gl.glTexParameterf(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR_MIPMAP_LINEAR if self.mipmap else gl.GL_NEAREST)
159
+ self.update(image)
160
+
161
+ def delete(self):
162
+ if self.gl_id is not None:
163
+ gl.glDeleteTextures([self.gl_id])
164
+ self.gl_id = None
165
+
166
+ def __del__(self):
167
+ try:
168
+ self.delete()
169
+ except:
170
+ pass
171
+
172
+ @contextlib.contextmanager
173
+ def bind(self):
174
+ prev_id = gl.glGetInteger(gl.GL_TEXTURE_BINDING_2D)
175
+ gl.glBindTexture(gl.GL_TEXTURE_2D, self.gl_id)
176
+ yield
177
+ gl.glBindTexture(gl.GL_TEXTURE_2D, prev_id)
178
+
179
+ def update(self, image):
180
+ if image is not None:
181
+ image = prepare_texture_data(image)
182
+ assert self.is_compatible(image=image)
183
+ with self.bind():
184
+ fmt = get_texture_format(self.dtype, self.channels)
185
+ gl.glPushClientAttrib(gl.GL_CLIENT_PIXEL_STORE_BIT)
186
+ gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)
187
+ gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, fmt.internalformat, self.width, self.height, 0, fmt.format, fmt.type, image)
188
+ if self.mipmap:
189
+ gl.glGenerateMipmap(gl.GL_TEXTURE_2D)
190
+ gl.glPopClientAttrib()
191
+
192
+ def draw(self, *, pos=0, zoom=1, align=0, rint=False, color=1, alpha=1, rounding=0):
193
+ zoom = np.broadcast_to(np.asarray(zoom, dtype='float32'), [2])
194
+ size = zoom * [self.width, self.height]
195
+ with self.bind():
196
+ gl.glPushAttrib(gl.GL_ENABLE_BIT)
197
+ gl.glEnable(gl.GL_TEXTURE_2D)
198
+ draw_rect(pos=pos, size=size, align=align, rint=rint, color=color, alpha=alpha, rounding=rounding)
199
+ gl.glPopAttrib()
200
+
201
+ def is_compatible(self, *, image=None, width=None, height=None, channels=None, dtype=None): # pylint: disable=too-many-return-statements
202
+ if image is not None:
203
+ if image.ndim != 3:
204
+ return False
205
+ ih, iw, ic = image.shape
206
+ if not self.is_compatible(width=iw, height=ih, channels=ic, dtype=image.dtype):
207
+ return False
208
+ if width is not None and self.width != width:
209
+ return False
210
+ if height is not None and self.height != height:
211
+ return False
212
+ if channels is not None and self.channels != channels:
213
+ return False
214
+ if dtype is not None and self.dtype != dtype:
215
+ return False
216
+ return True
217
+
218
+ #----------------------------------------------------------------------------
219
+
220
+ class Framebuffer:
221
+ def __init__(self, *, texture=None, width=None, height=None, channels=None, dtype=None, msaa=0):
222
+ self.texture = texture
223
+ self.gl_id = None
224
+ self.gl_color = None
225
+ self.gl_depth_stencil = None
226
+ self.msaa = msaa
227
+
228
+ # Determine size and dtype.
229
+ if texture is not None:
230
+ assert isinstance(self.texture, Texture)
231
+ self.width = texture.width
232
+ self.height = texture.height
233
+ self.channels = texture.channels
234
+ self.dtype = texture.dtype
235
+ else:
236
+ assert width is not None and height is not None
237
+ self.width = width
238
+ self.height = height
239
+ self.channels = channels if channels is not None else 4
240
+ self.dtype = np.dtype(dtype) if dtype is not None else np.float32
241
+
242
+ # Validate size and dtype.
243
+ assert isinstance(self.width, int) and self.width >= 0
244
+ assert isinstance(self.height, int) and self.height >= 0
245
+ assert isinstance(self.channels, int) and self.channels >= 1
246
+ assert width is None or width == self.width
247
+ assert height is None or height == self.height
248
+ assert channels is None or channels == self.channels
249
+ assert dtype is None or dtype == self.dtype
250
+
251
+ # Create framebuffer object.
252
+ self.gl_id = gl.glGenFramebuffers(1)
253
+ with self.bind():
254
+
255
+ # Setup color buffer.
256
+ if self.texture is not None:
257
+ assert self.msaa == 0
258
+ gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, self.texture.gl_id, 0)
259
+ else:
260
+ fmt = get_texture_format(self.dtype, self.channels)
261
+ self.gl_color = gl.glGenRenderbuffers(1)
262
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_color)
263
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, fmt.internalformat, self.width, self.height)
264
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, self.gl_color)
265
+
266
+ # Setup depth/stencil buffer.
267
+ self.gl_depth_stencil = gl.glGenRenderbuffers(1)
268
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, self.gl_depth_stencil)
269
+ gl.glRenderbufferStorageMultisample(gl.GL_RENDERBUFFER, self.msaa, gl.GL_DEPTH24_STENCIL8, self.width, self.height)
270
+ gl.glFramebufferRenderbuffer(gl.GL_FRAMEBUFFER, gl.GL_DEPTH_STENCIL_ATTACHMENT, gl.GL_RENDERBUFFER, self.gl_depth_stencil)
271
+
272
+ def delete(self):
273
+ if self.gl_id is not None:
274
+ gl.glDeleteFramebuffers([self.gl_id])
275
+ self.gl_id = None
276
+ if self.gl_color is not None:
277
+ gl.glDeleteRenderbuffers(1, [self.gl_color])
278
+ self.gl_color = None
279
+ if self.gl_depth_stencil is not None:
280
+ gl.glDeleteRenderbuffers(1, [self.gl_depth_stencil])
281
+ self.gl_depth_stencil = None
282
+
283
+ def __del__(self):
284
+ try:
285
+ self.delete()
286
+ except:
287
+ pass
288
+
289
+ @contextlib.contextmanager
290
+ def bind(self):
291
+ prev_fbo = gl.glGetInteger(gl.GL_FRAMEBUFFER_BINDING)
292
+ prev_rbo = gl.glGetInteger(gl.GL_RENDERBUFFER_BINDING)
293
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.gl_id)
294
+ if self.width is not None and self.height is not None:
295
+ gl.glViewport(0, 0, self.width, self.height)
296
+ yield
297
+ gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, prev_fbo)
298
+ gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, prev_rbo)
299
+
300
+ def blit(self, dst=None):
301
+ assert dst is None or isinstance(dst, Framebuffer)
302
+ with self.bind():
303
+ gl.glBindFramebuffer(gl.GL_DRAW_FRAMEBUFFER, 0 if dst is None else dst.fbo)
304
+ gl.glBlitFramebuffer(0, 0, self.width, self.height, 0, 0, self.width, self.height, gl.GL_COLOR_BUFFER_BIT, gl.GL_NEAREST)
305
+
306
+ #----------------------------------------------------------------------------
307
+
308
+ def draw_shape(vertices, *, mode=gl.GL_TRIANGLE_FAN, pos=0, size=1, color=1, alpha=1):
309
+ assert vertices.ndim == 2 and vertices.shape[1] == 2
310
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
311
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2])
312
+ color = np.broadcast_to(np.asarray(color, dtype='float32'), [3])
313
+ alpha = np.clip(np.broadcast_to(np.asarray(alpha, dtype='float32'), []), 0, 1)
314
+
315
+ gl.glPushClientAttrib(gl.GL_CLIENT_VERTEX_ARRAY_BIT)
316
+ gl.glPushAttrib(gl.GL_CURRENT_BIT | gl.GL_TRANSFORM_BIT)
317
+ gl.glMatrixMode(gl.GL_MODELVIEW)
318
+ gl.glPushMatrix()
319
+
320
+ gl.glEnableClientState(gl.GL_VERTEX_ARRAY)
321
+ gl.glEnableClientState(gl.GL_TEXTURE_COORD_ARRAY)
322
+ gl.glVertexPointer(2, gl.GL_FLOAT, 0, vertices)
323
+ gl.glTexCoordPointer(2, gl.GL_FLOAT, 0, vertices)
324
+ gl.glTranslate(pos[0], pos[1], 0)
325
+ gl.glScale(size[0], size[1], 1)
326
+ gl.glColor4f(color[0] * alpha, color[1] * alpha, color[2] * alpha, alpha)
327
+ gl.glDrawArrays(mode, 0, vertices.shape[0])
328
+
329
+ gl.glPopMatrix()
330
+ gl.glPopAttrib()
331
+ gl.glPopClientAttrib()
332
+
333
+ #----------------------------------------------------------------------------
334
+
335
+ def draw_rect(*, pos=0, pos2=None, size=None, align=0, rint=False, color=1, alpha=1, rounding=0):
336
+ assert pos2 is None or size is None
337
+ pos = np.broadcast_to(np.asarray(pos, dtype='float32'), [2])
338
+ pos2 = np.broadcast_to(np.asarray(pos2, dtype='float32'), [2]) if pos2 is not None else None
339
+ size = np.broadcast_to(np.asarray(size, dtype='float32'), [2]) if size is not None else None
340
+ size = size if size is not None else pos2 - pos if pos2 is not None else np.array([1, 1], dtype='float32')
341
+ pos = pos - size * align
342
+ if rint:
343
+ pos = np.rint(pos)
344
+ rounding = np.broadcast_to(np.asarray(rounding, dtype='float32'), [2])
345
+ rounding = np.minimum(np.abs(rounding) / np.maximum(np.abs(size), 1e-8), 0.5)
346
+ if np.min(rounding) == 0:
347
+ rounding *= 0
348
+ vertices = _setup_rect(float(rounding[0]), float(rounding[1]))
349
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_FAN, pos=pos, size=size, color=color, alpha=alpha)
350
+
351
+ @functools.lru_cache(maxsize=10000)
352
+ def _setup_rect(rx, ry):
353
+ t = np.linspace(0, np.pi / 2, 1 if max(rx, ry) == 0 else 64)
354
+ s = 1 - np.sin(t); c = 1 - np.cos(t)
355
+ x = [c * rx, 1 - s * rx, 1 - c * rx, s * rx]
356
+ y = [s * ry, c * ry, 1 - s * ry, 1 - c * ry]
357
+ v = np.stack([x, y], axis=-1).reshape(-1, 2)
358
+ return v.astype('float32')
359
+
360
+ #----------------------------------------------------------------------------
361
+
362
+ def draw_circle(*, center=0, radius=100, hole=0, color=1, alpha=1):
363
+ hole = np.broadcast_to(np.asarray(hole, dtype='float32'), [])
364
+ vertices = _setup_circle(float(hole))
365
+ draw_shape(vertices, mode=gl.GL_TRIANGLE_STRIP, pos=center, size=radius, color=color, alpha=alpha)
366
+
367
+ @functools.lru_cache(maxsize=10000)
368
+ def _setup_circle(hole):
369
+ t = np.linspace(0, np.pi * 2, 128)
370
+ s = np.sin(t); c = np.cos(t)
371
+ v = np.stack([c, s, c * hole, s * hole], axis=-1).reshape(-1, 2)
372
+ return v.astype('float32')
373
+
374
+ #----------------------------------------------------------------------------
gui_utils/glfw_window.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import time
10
+ import glfw
11
+ import OpenGL.GL as gl
12
+ from . import gl_utils
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
+ class GlfwWindow: # pylint: disable=too-many-public-methods
17
+ def __init__(self, *, title='GlfwWindow', window_width=1920, window_height=1080, deferred_show=True, close_on_esc=True):
18
+ self._glfw_window = None
19
+ self._drawing_frame = False
20
+ self._frame_start_time = None
21
+ self._frame_delta = 0
22
+ self._fps_limit = None
23
+ self._vsync = None
24
+ self._skip_frames = 0
25
+ self._deferred_show = deferred_show
26
+ self._close_on_esc = close_on_esc
27
+ self._esc_pressed = False
28
+ self._drag_and_drop_paths = None
29
+ self._capture_next_frame = False
30
+ self._captured_frame = None
31
+
32
+ # Create window.
33
+ glfw.init()
34
+ glfw.window_hint(glfw.VISIBLE, False)
35
+ self._glfw_window = glfw.create_window(width=window_width, height=window_height, title=title, monitor=None, share=None)
36
+ self._attach_glfw_callbacks()
37
+ self.make_context_current()
38
+
39
+ # Adjust window.
40
+ self.set_vsync(False)
41
+ self.set_window_size(window_width, window_height)
42
+ if not self._deferred_show:
43
+ glfw.show_window(self._glfw_window)
44
+
45
+ def close(self):
46
+ if self._drawing_frame:
47
+ self.end_frame()
48
+ if self._glfw_window is not None:
49
+ glfw.destroy_window(self._glfw_window)
50
+ self._glfw_window = None
51
+ #glfw.terminate() # Commented out to play it nice with other glfw clients.
52
+
53
+ def __del__(self):
54
+ try:
55
+ self.close()
56
+ except:
57
+ pass
58
+
59
+ @property
60
+ def window_width(self):
61
+ return self.content_width
62
+
63
+ @property
64
+ def window_height(self):
65
+ return self.content_height + self.title_bar_height
66
+
67
+ @property
68
+ def content_width(self):
69
+ width, _height = glfw.get_window_size(self._glfw_window)
70
+ return width
71
+
72
+ @property
73
+ def content_height(self):
74
+ _width, height = glfw.get_window_size(self._glfw_window)
75
+ return height
76
+
77
+ @property
78
+ def title_bar_height(self):
79
+ _left, top, _right, _bottom = glfw.get_window_frame_size(self._glfw_window)
80
+ return top
81
+
82
+ @property
83
+ def monitor_width(self):
84
+ _, _, width, _height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
85
+ return width
86
+
87
+ @property
88
+ def monitor_height(self):
89
+ _, _, _width, height = glfw.get_monitor_workarea(glfw.get_primary_monitor())
90
+ return height
91
+
92
+ @property
93
+ def frame_delta(self):
94
+ return self._frame_delta
95
+
96
+ def set_title(self, title):
97
+ glfw.set_window_title(self._glfw_window, title)
98
+
99
+ def set_window_size(self, width, height):
100
+ width = min(width, self.monitor_width)
101
+ height = min(height, self.monitor_height)
102
+ glfw.set_window_size(self._glfw_window, width, max(height - self.title_bar_height, 0))
103
+ if width == self.monitor_width and height == self.monitor_height:
104
+ self.maximize()
105
+
106
+ def set_content_size(self, width, height):
107
+ self.set_window_size(width, height + self.title_bar_height)
108
+
109
+ def maximize(self):
110
+ glfw.maximize_window(self._glfw_window)
111
+
112
+ def set_position(self, x, y):
113
+ glfw.set_window_pos(self._glfw_window, x, y + self.title_bar_height)
114
+
115
+ def center(self):
116
+ self.set_position((self.monitor_width - self.window_width) // 2, (self.monitor_height - self.window_height) // 2)
117
+
118
+ def set_vsync(self, vsync):
119
+ vsync = bool(vsync)
120
+ if vsync != self._vsync:
121
+ glfw.swap_interval(1 if vsync else 0)
122
+ self._vsync = vsync
123
+
124
+ def set_fps_limit(self, fps_limit):
125
+ self._fps_limit = int(fps_limit)
126
+
127
+ def should_close(self):
128
+ return glfw.window_should_close(self._glfw_window) or (self._close_on_esc and self._esc_pressed)
129
+
130
+ def skip_frame(self):
131
+ self.skip_frames(1)
132
+
133
+ def skip_frames(self, num): # Do not update window for the next N frames.
134
+ self._skip_frames = max(self._skip_frames, int(num))
135
+
136
+ def is_skipping_frames(self):
137
+ return self._skip_frames > 0
138
+
139
+ def capture_next_frame(self):
140
+ self._capture_next_frame = True
141
+
142
+ def pop_captured_frame(self):
143
+ frame = self._captured_frame
144
+ self._captured_frame = None
145
+ return frame
146
+
147
+ def pop_drag_and_drop_paths(self):
148
+ paths = self._drag_and_drop_paths
149
+ self._drag_and_drop_paths = None
150
+ return paths
151
+
152
+ def draw_frame(self): # To be overridden by subclass.
153
+ self.begin_frame()
154
+ # Rendering code goes here.
155
+ self.end_frame()
156
+
157
+ def make_context_current(self):
158
+ if self._glfw_window is not None:
159
+ glfw.make_context_current(self._glfw_window)
160
+
161
+ def begin_frame(self):
162
+ # End previous frame.
163
+ if self._drawing_frame:
164
+ self.end_frame()
165
+
166
+ # Apply FPS limit.
167
+ if self._frame_start_time is not None and self._fps_limit is not None:
168
+ delay = self._frame_start_time - time.perf_counter() + 1 / self._fps_limit
169
+ if delay > 0:
170
+ time.sleep(delay)
171
+ cur_time = time.perf_counter()
172
+ if self._frame_start_time is not None:
173
+ self._frame_delta = cur_time - self._frame_start_time
174
+ self._frame_start_time = cur_time
175
+
176
+ # Process events.
177
+ glfw.poll_events()
178
+
179
+ # Begin frame.
180
+ self._drawing_frame = True
181
+ self.make_context_current()
182
+
183
+ # Initialize GL state.
184
+ gl.glViewport(0, 0, self.content_width, self.content_height)
185
+ gl.glMatrixMode(gl.GL_PROJECTION)
186
+ gl.glLoadIdentity()
187
+ gl.glTranslate(-1, 1, 0)
188
+ gl.glScale(2 / max(self.content_width, 1), -2 / max(self.content_height, 1), 1)
189
+ gl.glMatrixMode(gl.GL_MODELVIEW)
190
+ gl.glLoadIdentity()
191
+ gl.glEnable(gl.GL_BLEND)
192
+ gl.glBlendFunc(gl.GL_ONE, gl.GL_ONE_MINUS_SRC_ALPHA) # Pre-multiplied alpha.
193
+
194
+ # Clear.
195
+ gl.glClearColor(0, 0, 0, 1)
196
+ gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
197
+
198
+ def end_frame(self):
199
+ assert self._drawing_frame
200
+ self._drawing_frame = False
201
+
202
+ # Skip frames if requested.
203
+ if self._skip_frames > 0:
204
+ self._skip_frames -= 1
205
+ return
206
+
207
+ # Capture frame if requested.
208
+ if self._capture_next_frame:
209
+ self._captured_frame = gl_utils.read_pixels(self.content_width, self.content_height)
210
+ self._capture_next_frame = False
211
+
212
+ # Update window.
213
+ if self._deferred_show:
214
+ glfw.show_window(self._glfw_window)
215
+ self._deferred_show = False
216
+ glfw.swap_buffers(self._glfw_window)
217
+
218
+ def _attach_glfw_callbacks(self):
219
+ glfw.set_key_callback(self._glfw_window, self._glfw_key_callback)
220
+ glfw.set_drop_callback(self._glfw_window, self._glfw_drop_callback)
221
+
222
+ def _glfw_key_callback(self, _window, key, _scancode, action, _mods):
223
+ if action == glfw.PRESS and key == glfw.KEY_ESCAPE:
224
+ self._esc_pressed = True
225
+
226
+ def _glfw_drop_callback(self, _window, paths):
227
+ self._drag_and_drop_paths = paths
228
+
229
+ #----------------------------------------------------------------------------
gui_utils/imgui_utils.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import contextlib
10
+ import imgui
11
+
12
+ #----------------------------------------------------------------------------
13
+
14
+ def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27):
15
+ s = imgui.get_style()
16
+ s.window_padding = [spacing, spacing]
17
+ s.item_spacing = [spacing, spacing]
18
+ s.item_inner_spacing = [spacing, spacing]
19
+ s.columns_min_spacing = spacing
20
+ s.indent_spacing = indent
21
+ s.scrollbar_size = scrollbar
22
+ s.frame_padding = [4, 3]
23
+ s.window_border_size = 1
24
+ s.child_border_size = 1
25
+ s.popup_border_size = 1
26
+ s.frame_border_size = 1
27
+ s.window_rounding = 0
28
+ s.child_rounding = 0
29
+ s.popup_rounding = 3
30
+ s.frame_rounding = 3
31
+ s.scrollbar_rounding = 3
32
+ s.grab_rounding = 3
33
+
34
+ getattr(imgui, f'style_colors_{color_scheme}')(s)
35
+ c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
36
+ c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND]
37
+ s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1]
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
+ @contextlib.contextmanager
42
+ def grayed_out(cond=True):
43
+ if cond:
44
+ s = imgui.get_style()
45
+ text = s.colors[imgui.COLOR_TEXT_DISABLED]
46
+ grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB]
47
+ back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND]
48
+ imgui.push_style_color(imgui.COLOR_TEXT, *text)
49
+ imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab)
50
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab)
51
+ imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab)
52
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back)
53
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back)
54
+ imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back)
55
+ imgui.push_style_color(imgui.COLOR_BUTTON, *back)
56
+ imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back)
57
+ imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back)
58
+ imgui.push_style_color(imgui.COLOR_HEADER, *back)
59
+ imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back)
60
+ imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back)
61
+ imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back)
62
+ yield
63
+ imgui.pop_style_color(14)
64
+ else:
65
+ yield
66
+
67
+ #----------------------------------------------------------------------------
68
+
69
+ @contextlib.contextmanager
70
+ def item_width(width=None):
71
+ if width is not None:
72
+ imgui.push_item_width(width)
73
+ yield
74
+ imgui.pop_item_width()
75
+ else:
76
+ yield
77
+
78
+ #----------------------------------------------------------------------------
79
+
80
+ def scoped_by_object_id(method):
81
+ def decorator(self, *args, **kwargs):
82
+ imgui.push_id(str(id(self)))
83
+ res = method(self, *args, **kwargs)
84
+ imgui.pop_id()
85
+ return res
86
+ return decorator
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
+ def button(label, width=0, enabled=True):
91
+ with grayed_out(not enabled):
92
+ clicked = imgui.button(label, width=width)
93
+ clicked = clicked and enabled
94
+ return clicked
95
+
96
+ #----------------------------------------------------------------------------
97
+
98
+ def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True):
99
+ expanded = False
100
+ if show:
101
+ if default:
102
+ flags |= imgui.TREE_NODE_DEFAULT_OPEN
103
+ if not enabled:
104
+ flags |= imgui.TREE_NODE_LEAF
105
+ with grayed_out(not enabled):
106
+ expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags)
107
+ expanded = expanded and enabled
108
+ return expanded, visible
109
+
110
+ #----------------------------------------------------------------------------
111
+
112
+ def popup_button(label, width=0, enabled=True):
113
+ if button(label, width, enabled):
114
+ imgui.open_popup(label)
115
+ opened = imgui.begin_popup(label)
116
+ return opened
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ def input_text(label, value, buffer_length, flags, width=None, help_text=''):
121
+ old_value = value
122
+ color = list(imgui.get_style().colors[imgui.COLOR_TEXT])
123
+ if value == '':
124
+ color[-1] *= 0.5
125
+ with item_width(width):
126
+ imgui.push_style_color(imgui.COLOR_TEXT, *color)
127
+ value = value if value != '' else help_text
128
+ changed, value = imgui.input_text(label, value, buffer_length, flags)
129
+ value = value if value != help_text else ''
130
+ imgui.pop_style_color(1)
131
+ if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE:
132
+ changed = (value != old_value)
133
+ return changed, value
134
+
135
+ #----------------------------------------------------------------------------
136
+
137
+ def drag_previous_control(enabled=True):
138
+ dragging = False
139
+ dx = 0
140
+ dy = 0
141
+ if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP):
142
+ if enabled:
143
+ dragging = True
144
+ dx, dy = imgui.get_mouse_drag_delta()
145
+ imgui.reset_mouse_drag_delta()
146
+ imgui.end_drag_drop_source()
147
+ return dragging, dx, dy
148
+
149
+ #----------------------------------------------------------------------------
150
+
151
+ def drag_button(label, width=0, enabled=True):
152
+ clicked = button(label, width=width, enabled=enabled)
153
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
154
+ return clicked, dragging, dx, dy
155
+
156
+ #----------------------------------------------------------------------------
157
+
158
+ def drag_hidden_window(label, x, y, width, height, enabled=True):
159
+ imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0)
160
+ imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0)
161
+ imgui.set_next_window_position(x, y)
162
+ imgui.set_next_window_size(width, height)
163
+ imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE))
164
+ dragging, dx, dy = drag_previous_control(enabled=enabled)
165
+ imgui.end()
166
+ imgui.pop_style_color(2)
167
+ return dragging, dx, dy
168
+
169
+ #----------------------------------------------------------------------------
gui_utils/imgui_window.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import imgui
11
+ import imgui.integrations.glfw
12
+
13
+ from . import glfw_window
14
+ from . import imgui_utils
15
+ from . import text_utils
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ class ImguiWindow(glfw_window.GlfwWindow):
20
+ def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs):
21
+ if font is None:
22
+ font = text_utils.get_default_font()
23
+ font_sizes = {int(size) for size in font_sizes}
24
+ super().__init__(title=title, **glfw_kwargs)
25
+
26
+ # Init fields.
27
+ self._imgui_context = None
28
+ self._imgui_renderer = None
29
+ self._imgui_fonts = None
30
+ self._cur_font_size = max(font_sizes)
31
+
32
+ # Delete leftover imgui.ini to avoid unexpected behavior.
33
+ if os.path.isfile('imgui.ini'):
34
+ os.remove('imgui.ini')
35
+
36
+ # Init ImGui.
37
+ self._imgui_context = imgui.create_context()
38
+ self._imgui_renderer = _GlfwRenderer(self._glfw_window)
39
+ self._attach_glfw_callbacks()
40
+ imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime.
41
+ imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom().
42
+ self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes}
43
+ self._imgui_renderer.refresh_font_texture()
44
+
45
+ def close(self):
46
+ self.make_context_current()
47
+ self._imgui_fonts = None
48
+ if self._imgui_renderer is not None:
49
+ self._imgui_renderer.shutdown()
50
+ self._imgui_renderer = None
51
+ if self._imgui_context is not None:
52
+ #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end.
53
+ self._imgui_context = None
54
+ super().close()
55
+
56
+ def _glfw_key_callback(self, *args):
57
+ super()._glfw_key_callback(*args)
58
+ self._imgui_renderer.keyboard_callback(*args)
59
+
60
+ @property
61
+ def font_size(self):
62
+ return self._cur_font_size
63
+
64
+ @property
65
+ def spacing(self):
66
+ return round(self._cur_font_size * 0.4)
67
+
68
+ def set_font_size(self, target): # Applied on next frame.
69
+ self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1]
70
+
71
+ def begin_frame(self):
72
+ # Begin glfw frame.
73
+ super().begin_frame()
74
+
75
+ # Process imgui events.
76
+ self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10
77
+ if self.content_width > 0 and self.content_height > 0:
78
+ self._imgui_renderer.process_inputs()
79
+
80
+ # Begin imgui frame.
81
+ imgui.new_frame()
82
+ imgui.push_font(self._imgui_fonts[self._cur_font_size])
83
+ imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4)
84
+
85
+ def end_frame(self):
86
+ imgui.pop_font()
87
+ imgui.render()
88
+ imgui.end_frame()
89
+ self._imgui_renderer.render(imgui.get_draw_data())
90
+ super().end_frame()
91
+
92
+ #----------------------------------------------------------------------------
93
+ # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux.
94
+
95
+ class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer):
96
+ def __init__(self, *args, **kwargs):
97
+ super().__init__(*args, **kwargs)
98
+ self.mouse_wheel_multiplier = 1
99
+
100
+ def scroll_callback(self, window, x_offset, y_offset):
101
+ self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier
102
+
103
+ #----------------------------------------------------------------------------
gui_utils/text_utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import functools
10
+ from typing import Optional
11
+
12
+ import dnnlib
13
+ import numpy as np
14
+ import PIL.Image
15
+ import PIL.ImageFont
16
+ import scipy.ndimage
17
+
18
+ from . import gl_utils
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ def get_default_font():
23
+ url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular
24
+ return dnnlib.util.open_url(url, return_filename=True)
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ @functools.lru_cache(maxsize=None)
29
+ def get_pil_font(font=None, size=32):
30
+ if font is None:
31
+ font = get_default_font()
32
+ return PIL.ImageFont.truetype(font=font, size=size)
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def get_array(string, *, dropshadow_radius: int=None, **kwargs):
37
+ if dropshadow_radius is not None:
38
+ offset_x = int(np.ceil(dropshadow_radius*2/3))
39
+ offset_y = int(np.ceil(dropshadow_radius*2/3))
40
+ return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
41
+ else:
42
+ return _get_array_priv(string, **kwargs)
43
+
44
+ @functools.lru_cache(maxsize=10000)
45
+ def _get_array_priv(
46
+ string: str, *,
47
+ size: int = 32,
48
+ max_width: Optional[int]=None,
49
+ max_height: Optional[int]=None,
50
+ min_size=10,
51
+ shrink_coef=0.8,
52
+ dropshadow_radius: int=None,
53
+ offset_x: int=None,
54
+ offset_y: int=None,
55
+ **kwargs
56
+ ):
57
+ cur_size = size
58
+ array = None
59
+ while True:
60
+ if dropshadow_radius is not None:
61
+ # separate implementation for dropshadow text rendering
62
+ array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs)
63
+ else:
64
+ array = _get_array_impl(string, size=cur_size, **kwargs)
65
+ height, width, _ = array.shape
66
+ if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size):
67
+ break
68
+ cur_size = max(int(cur_size * shrink_coef), min_size)
69
+ return array
70
+
71
+ #----------------------------------------------------------------------------
72
+
73
+ @functools.lru_cache(maxsize=10000)
74
+ def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None):
75
+ pil_font = get_pil_font(font=font, size=size)
76
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
77
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
78
+ width = max(line.shape[1] for line in lines)
79
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
80
+ line_spacing = line_pad if line_pad is not None else size // 2
81
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
82
+ mask = np.concatenate(lines, axis=0)
83
+ alpha = mask
84
+ if outline > 0:
85
+ mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0)
86
+ alpha = mask.astype(np.float32) / 255
87
+ alpha = scipy.ndimage.gaussian_filter(alpha, outline)
88
+ alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp
89
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
90
+ alpha = np.maximum(alpha, mask)
91
+ return np.stack([mask, alpha], axis=-1)
92
+
93
+ #----------------------------------------------------------------------------
94
+
95
+ @functools.lru_cache(maxsize=10000)
96
+ def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs):
97
+ assert (offset_x > 0) and (offset_y > 0)
98
+ pil_font = get_pil_font(font=font, size=size)
99
+ lines = [pil_font.getmask(line, 'L') for line in string.split('\n')]
100
+ lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines]
101
+ width = max(line.shape[1] for line in lines)
102
+ lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines]
103
+ line_spacing = line_pad if line_pad is not None else size // 2
104
+ lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:]
105
+ mask = np.concatenate(lines, axis=0)
106
+ alpha = mask
107
+
108
+ mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0)
109
+ alpha = mask.astype(np.float32) / 255
110
+ alpha = scipy.ndimage.gaussian_filter(alpha, radius)
111
+ alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4
112
+ alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8)
113
+ alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x]
114
+ alpha = np.maximum(alpha, mask)
115
+ return np.stack([mask, alpha], axis=-1)
116
+
117
+ #----------------------------------------------------------------------------
118
+
119
+ @functools.lru_cache(maxsize=10000)
120
+ def get_texture(string, bilinear=True, mipmap=True, **kwargs):
121
+ return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap)
122
+
123
+ #----------------------------------------------------------------------------
legacy.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Converting legacy network pickle into the new format."""
10
+
11
+ import click
12
+ import pickle
13
+ import re
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+ from torch_utils import misc
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ def load_network_pkl(f, force_fp16=False):
23
+ data = _LegacyUnpickler(f).load()
24
+
25
+ # Legacy TensorFlow pickle => convert.
26
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
27
+ tf_G, tf_D, tf_Gs = data
28
+ G = convert_tf_generator(tf_G)
29
+ D = convert_tf_discriminator(tf_D)
30
+ G_ema = convert_tf_generator(tf_Gs)
31
+ data = dict(G=G, D=D, G_ema=G_ema)
32
+
33
+ # Add missing fields.
34
+ if 'training_set_kwargs' not in data:
35
+ data['training_set_kwargs'] = None
36
+ if 'augment_pipe' not in data:
37
+ data['augment_pipe'] = None
38
+
39
+ # Validate contents.
40
+ assert isinstance(data['G'], torch.nn.Module)
41
+ assert isinstance(data['D'], torch.nn.Module)
42
+ assert isinstance(data['G_ema'], torch.nn.Module)
43
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
44
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
45
+
46
+ # Force FP16.
47
+ if force_fp16:
48
+ for key in ['G', 'D', 'G_ema']:
49
+ old = data[key]
50
+ kwargs = copy.deepcopy(old.init_kwargs)
51
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
52
+ fp16_kwargs.num_fp16_res = 4
53
+ fp16_kwargs.conv_clamp = 256
54
+ if kwargs != old.init_kwargs:
55
+ new = type(old)(**kwargs).eval().requires_grad_(False)
56
+ misc.copy_params_and_buffers(old, new, require_all=True)
57
+ data[key] = new
58
+ return data
59
+
60
+ #----------------------------------------------------------------------------
61
+
62
+ class _TFNetworkStub(dnnlib.EasyDict):
63
+ pass
64
+
65
+ class _LegacyUnpickler(pickle.Unpickler):
66
+ def find_class(self, module, name):
67
+ if module == 'dnnlib.tflib.network' and name == 'Network':
68
+ return _TFNetworkStub
69
+ return super().find_class(module, name)
70
+
71
+ #----------------------------------------------------------------------------
72
+
73
+ def _collect_tf_params(tf_net):
74
+ # pylint: disable=protected-access
75
+ tf_params = dict()
76
+ def recurse(prefix, tf_net):
77
+ for name, value in tf_net.variables:
78
+ tf_params[prefix + name] = value
79
+ for name, comp in tf_net.components.items():
80
+ recurse(prefix + name + '/', comp)
81
+ recurse('', tf_net)
82
+ return tf_params
83
+
84
+ #----------------------------------------------------------------------------
85
+
86
+ def _populate_module_params(module, *patterns):
87
+ for name, tensor in misc.named_params_and_buffers(module):
88
+ found = False
89
+ value = None
90
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
91
+ match = re.fullmatch(pattern, name)
92
+ if match:
93
+ found = True
94
+ if value_fn is not None:
95
+ value = value_fn(*match.groups())
96
+ break
97
+ try:
98
+ assert found
99
+ if value is not None:
100
+ tensor.copy_(torch.from_numpy(np.array(value)))
101
+ except:
102
+ print(name, list(tensor.shape))
103
+ raise
104
+
105
+ #----------------------------------------------------------------------------
106
+
107
+ def convert_tf_generator(tf_G):
108
+ if tf_G.version < 4:
109
+ raise ValueError('TensorFlow pickle version too low')
110
+
111
+ # Collect kwargs.
112
+ tf_kwargs = tf_G.static_kwargs
113
+ known_kwargs = set()
114
+ def kwarg(tf_name, default=None, none=None):
115
+ known_kwargs.add(tf_name)
116
+ val = tf_kwargs.get(tf_name, default)
117
+ return val if val is not None else none
118
+
119
+ # Convert kwargs.
120
+ from training import networks_stylegan2
121
+ network_class = networks_stylegan2.Generator
122
+ kwargs = dnnlib.EasyDict(
123
+ z_dim = kwarg('latent_size', 512),
124
+ c_dim = kwarg('label_size', 0),
125
+ w_dim = kwarg('dlatent_size', 512),
126
+ img_resolution = kwarg('resolution', 1024),
127
+ img_channels = kwarg('num_channels', 3),
128
+ channel_base = kwarg('fmap_base', 16384) * 2,
129
+ channel_max = kwarg('fmap_max', 512),
130
+ num_fp16_res = kwarg('num_fp16_res', 0),
131
+ conv_clamp = kwarg('conv_clamp', None),
132
+ architecture = kwarg('architecture', 'skip'),
133
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
134
+ use_noise = kwarg('use_noise', True),
135
+ activation = kwarg('nonlinearity', 'lrelu'),
136
+ mapping_kwargs = dnnlib.EasyDict(
137
+ num_layers = kwarg('mapping_layers', 8),
138
+ embed_features = kwarg('label_fmaps', None),
139
+ layer_features = kwarg('mapping_fmaps', None),
140
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
141
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
142
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
143
+ ),
144
+ )
145
+
146
+ # Check for unknown kwargs.
147
+ kwarg('truncation_psi')
148
+ kwarg('truncation_cutoff')
149
+ kwarg('style_mixing_prob')
150
+ kwarg('structure')
151
+ kwarg('conditioning')
152
+ kwarg('fused_modconv')
153
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154
+ if len(unknown_kwargs) > 0:
155
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156
+
157
+ # Collect params.
158
+ tf_params = _collect_tf_params(tf_G)
159
+ for name, value in list(tf_params.items()):
160
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161
+ if match:
162
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
163
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164
+ kwargs.synthesis.kwargs.architecture = 'orig'
165
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166
+
167
+ # Convert params.
168
+ G = network_class(**kwargs).eval().requires_grad_(False)
169
+ # pylint: disable=unnecessary-lambda
170
+ # pylint: disable=f-string-without-interpolation
171
+ _populate_module_params(G,
172
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201
+ r'.*\.resample_filter', None,
202
+ r'.*\.act_filter', None,
203
+ )
204
+ return G
205
+
206
+ #----------------------------------------------------------------------------
207
+
208
+ def convert_tf_discriminator(tf_D):
209
+ if tf_D.version < 4:
210
+ raise ValueError('TensorFlow pickle version too low')
211
+
212
+ # Collect kwargs.
213
+ tf_kwargs = tf_D.static_kwargs
214
+ known_kwargs = set()
215
+ def kwarg(tf_name, default=None):
216
+ known_kwargs.add(tf_name)
217
+ return tf_kwargs.get(tf_name, default)
218
+
219
+ # Convert kwargs.
220
+ kwargs = dnnlib.EasyDict(
221
+ c_dim = kwarg('label_size', 0),
222
+ img_resolution = kwarg('resolution', 1024),
223
+ img_channels = kwarg('num_channels', 3),
224
+ architecture = kwarg('architecture', 'resnet'),
225
+ channel_base = kwarg('fmap_base', 16384) * 2,
226
+ channel_max = kwarg('fmap_max', 512),
227
+ num_fp16_res = kwarg('num_fp16_res', 0),
228
+ conv_clamp = kwarg('conv_clamp', None),
229
+ cmap_dim = kwarg('mapping_fmaps', None),
230
+ block_kwargs = dnnlib.EasyDict(
231
+ activation = kwarg('nonlinearity', 'lrelu'),
232
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
233
+ freeze_layers = kwarg('freeze_layers', 0),
234
+ ),
235
+ mapping_kwargs = dnnlib.EasyDict(
236
+ num_layers = kwarg('mapping_layers', 0),
237
+ embed_features = kwarg('mapping_fmaps', None),
238
+ layer_features = kwarg('mapping_fmaps', None),
239
+ activation = kwarg('nonlinearity', 'lrelu'),
240
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
241
+ ),
242
+ epilogue_kwargs = dnnlib.EasyDict(
243
+ mbstd_group_size = kwarg('mbstd_group_size', None),
244
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
245
+ activation = kwarg('nonlinearity', 'lrelu'),
246
+ ),
247
+ )
248
+
249
+ # Check for unknown kwargs.
250
+ kwarg('structure')
251
+ kwarg('conditioning')
252
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
253
+ if len(unknown_kwargs) > 0:
254
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
255
+
256
+ # Collect params.
257
+ tf_params = _collect_tf_params(tf_D)
258
+ for name, value in list(tf_params.items()):
259
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
260
+ if match:
261
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
262
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
263
+ kwargs.architecture = 'orig'
264
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
265
+
266
+ # Convert params.
267
+ from training import networks_stylegan2
268
+ D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
269
+ # pylint: disable=unnecessary-lambda
270
+ # pylint: disable=f-string-without-interpolation
271
+ _populate_module_params(D,
272
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
273
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
274
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
275
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
276
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
277
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
278
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
279
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
280
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
281
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
282
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
283
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
284
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
285
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
286
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
287
+ r'.*\.resample_filter', None,
288
+ )
289
+ return D
290
+
291
+ #----------------------------------------------------------------------------
292
+
293
+ @click.command()
294
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
295
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
296
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
297
+ def convert_network_pickle(source, dest, force_fp16):
298
+ """Convert legacy network pickle into the native PyTorch format.
299
+
300
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
301
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
302
+
303
+ Example:
304
+
305
+ \b
306
+ python legacy.py \\
307
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
308
+ --dest=stylegan2-cat-config-f.pkl
309
+ """
310
+ print(f'Loading "{source}"...')
311
+ with dnnlib.util.open_url(source) as f:
312
+ data = load_network_pkl(f, force_fp16=force_fp16)
313
+ print(f'Saving "{dest}"...')
314
+ with open(dest, 'wb') as f:
315
+ pickle.dump(data, f)
316
+ print('Done.')
317
+
318
+ #----------------------------------------------------------------------------
319
+
320
+ if __name__ == "__main__":
321
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
322
+
323
+ #----------------------------------------------------------------------------
metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
metrics/equivariance.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
10
+ "Alias-Free Generative Adversarial Networks"."""
11
+
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import torch.fft
16
+ from torch_utils.ops import upfirdn2d
17
+ from . import metric_utils
18
+
19
+ #----------------------------------------------------------------------------
20
+ # Utilities.
21
+
22
+ def sinc(x):
23
+ y = (x * np.pi).abs()
24
+ z = torch.sin(y) / y.clamp(1e-30, float('inf'))
25
+ return torch.where(y < 1e-30, torch.ones_like(x), z)
26
+
27
+ def lanczos_window(x, a):
28
+ x = x.abs() / a
29
+ return torch.where(x < 1, sinc(x), torch.zeros_like(x))
30
+
31
+ def rotation_matrix(angle):
32
+ angle = torch.as_tensor(angle).to(torch.float32)
33
+ mat = torch.eye(3, device=angle.device)
34
+ mat[0, 0] = angle.cos()
35
+ mat[0, 1] = angle.sin()
36
+ mat[1, 0] = -angle.sin()
37
+ mat[1, 1] = angle.cos()
38
+ return mat
39
+
40
+ #----------------------------------------------------------------------------
41
+ # Apply integer translation to a batch of 2D images. Corresponds to the
42
+ # operator T_x in Appendix E.1.
43
+
44
+ def apply_integer_translation(x, tx, ty):
45
+ _N, _C, H, W = x.shape
46
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
47
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
48
+ ix = tx.round().to(torch.int64)
49
+ iy = ty.round().to(torch.int64)
50
+
51
+ z = torch.zeros_like(x)
52
+ m = torch.zeros_like(x)
53
+ if abs(ix) < W and abs(iy) < H:
54
+ y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
55
+ z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
56
+ m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
57
+ return z, m
58
+
59
+ #----------------------------------------------------------------------------
60
+ # Apply integer translation to a batch of 2D images. Corresponds to the
61
+ # operator T_x in Appendix E.2.
62
+
63
+ def apply_fractional_translation(x, tx, ty, a=3):
64
+ _N, _C, H, W = x.shape
65
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
66
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
67
+ ix = tx.floor().to(torch.int64)
68
+ iy = ty.floor().to(torch.int64)
69
+ fx = tx - ix
70
+ fy = ty - iy
71
+ b = a - 1
72
+
73
+ z = torch.zeros_like(x)
74
+ zx0 = max(ix - b, 0)
75
+ zy0 = max(iy - b, 0)
76
+ zx1 = min(ix + a, 0) + W
77
+ zy1 = min(iy + a, 0) + H
78
+ if zx0 < zx1 and zy0 < zy1:
79
+ taps = torch.arange(a * 2, device=x.device) - b
80
+ filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
81
+ filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
82
+ y = x
83
+ y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
84
+ y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
85
+ y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
86
+ z[:, :, zy0:zy1, zx0:zx1] = y
87
+
88
+ m = torch.zeros_like(x)
89
+ mx0 = max(ix + a, 0)
90
+ my0 = max(iy + a, 0)
91
+ mx1 = min(ix - b, 0) + W
92
+ my1 = min(iy - b, 0) + H
93
+ if mx0 < mx1 and my0 < my1:
94
+ m[:, :, my0:my1, mx0:mx1] = 1
95
+ return z, m
96
+
97
+ #----------------------------------------------------------------------------
98
+ # Construct an oriented low-pass filter that applies the appropriate
99
+ # bandlimit with respect to the input and output of the given affine 2D
100
+ # image transformation.
101
+
102
+ def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
103
+ assert a <= amax < aflt
104
+ mat = torch.as_tensor(mat).to(torch.float32)
105
+
106
+ # Construct 2D filter taps in input & output coordinate spaces.
107
+ taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
108
+ yi, xi = torch.meshgrid(taps, taps)
109
+ xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
110
+
111
+ # Convolution of two oriented 2D sinc filters.
112
+ fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
113
+ fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
114
+ f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
115
+
116
+ # Convolution of two oriented 2D Lanczos windows.
117
+ wi = lanczos_window(xi, a) * lanczos_window(yi, a)
118
+ wo = lanczos_window(xo, a) * lanczos_window(yo, a)
119
+ w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
120
+
121
+ # Construct windowed FIR filter.
122
+ f = f * w
123
+
124
+ # Finalize.
125
+ c = (aflt - amax) * up
126
+ f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
127
+ f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
128
+ f = f / f.sum([0,2], keepdim=True) / (up ** 2)
129
+ f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
130
+ return f
131
+
132
+ #----------------------------------------------------------------------------
133
+ # Apply the given affine transformation to a batch of 2D images.
134
+
135
+ def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
136
+ _N, _C, H, W = x.shape
137
+ mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
138
+
139
+ # Construct filter.
140
+ f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
141
+ assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
142
+ p = f.shape[0] // 2
143
+
144
+ # Construct sampling grid.
145
+ theta = mat.inverse()
146
+ theta[:2, 2] *= 2
147
+ theta[0, 2] += 1 / up / W
148
+ theta[1, 2] += 1 / up / H
149
+ theta[0, :] *= W / (W + p / up * 2)
150
+ theta[1, :] *= H / (H + p / up * 2)
151
+ theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
152
+ g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
153
+
154
+ # Resample image.
155
+ y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
156
+ z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
157
+
158
+ # Form mask.
159
+ m = torch.zeros_like(y)
160
+ c = p * 2 + 1
161
+ m[:, :, c:-c, c:-c] = 1
162
+ m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
163
+ return z, m
164
+
165
+ #----------------------------------------------------------------------------
166
+ # Apply fractional rotation to a batch of 2D images. Corresponds to the
167
+ # operator R_\alpha in Appendix E.3.
168
+
169
+ def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
170
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
171
+ mat = rotation_matrix(angle)
172
+ return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
173
+
174
+ #----------------------------------------------------------------------------
175
+ # Modify the frequency content of a batch of 2D images as if they had undergo
176
+ # fractional rotation -- but without actually rotating them. Corresponds to
177
+ # the operator R^*_\alpha in Appendix E.3.
178
+
179
+ def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
180
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
181
+ mat = rotation_matrix(-angle)
182
+ f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
183
+ y = upfirdn2d.filter2d(x=x, f=f)
184
+ m = torch.zeros_like(y)
185
+ c = f.shape[0] // 2
186
+ m[:, :, c:-c, c:-c] = 1
187
+ return y, m
188
+
189
+ #----------------------------------------------------------------------------
190
+ # Compute the selected equivariance metrics for the given generator.
191
+
192
+ def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
193
+ assert compute_eqt_int or compute_eqt_frac or compute_eqr
194
+
195
+ # Setup generator and labels.
196
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
197
+ I = torch.eye(3, device=opts.device)
198
+ M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
199
+ if M is None:
200
+ raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
201
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
202
+
203
+ # Sampling loop.
204
+ sums = None
205
+ progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
206
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
207
+ progress.update(batch_start)
208
+ s = []
209
+
210
+ # Randomize noise buffers, if any.
211
+ for name, buf in G.named_buffers():
212
+ if name.endswith('.noise_const'):
213
+ buf.copy_(torch.randn_like(buf))
214
+
215
+ # Run mapping network.
216
+ z = torch.randn([batch_size, G.z_dim], device=opts.device)
217
+ c = next(c_iter)
218
+ ws = G.mapping(z=z, c=c)
219
+
220
+ # Generate reference image.
221
+ M[:] = I
222
+ orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
223
+
224
+ # Integer translation (EQ-T).
225
+ if compute_eqt_int:
226
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
227
+ t = (t * G.img_resolution).round() / G.img_resolution
228
+ M[:] = I
229
+ M[:2, 2] = -t
230
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
231
+ ref, mask = apply_integer_translation(orig, t[0], t[1])
232
+ s += [(ref - img).square() * mask, mask]
233
+
234
+ # Fractional translation (EQ-T_frac).
235
+ if compute_eqt_frac:
236
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
237
+ M[:] = I
238
+ M[:2, 2] = -t
239
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
240
+ ref, mask = apply_fractional_translation(orig, t[0], t[1])
241
+ s += [(ref - img).square() * mask, mask]
242
+
243
+ # Rotation (EQ-R).
244
+ if compute_eqr:
245
+ angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
246
+ M[:] = rotation_matrix(-angle)
247
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
248
+ ref, ref_mask = apply_fractional_rotation(orig, angle)
249
+ pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
250
+ mask = ref_mask * pseudo_mask
251
+ s += [(ref - pseudo).square() * mask, mask]
252
+
253
+ # Accumulate results.
254
+ s = torch.stack([x.to(torch.float64).sum() for x in s])
255
+ sums = sums + s if sums is not None else s
256
+ progress.update(num_samples)
257
+
258
+ # Compute PSNRs.
259
+ if opts.num_gpus > 1:
260
+ torch.distributed.all_reduce(sums)
261
+ sums = sums.cpu()
262
+ mses = sums[0::2] / sums[1::2]
263
+ psnrs = np.log10(2) * 20 - mses.log10() * 10
264
+ psnrs = tuple(psnrs.numpy())
265
+ return psnrs[0] if len(psnrs) == 1 else psnrs
266
+
267
+ #----------------------------------------------------------------------------
metrics/frechet_inception_distance.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Frechet Inception Distance (FID) from the paper
10
+ "GANs trained by a two time-scale update rule converge to a local Nash
11
+ equilibrium". Matches the original implementation by Heusel et al. at
12
+ https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13
+
14
+ import numpy as np
15
+ import scipy.linalg
16
+ from . import metric_utils
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def compute_fid(opts, max_real, num_gen):
21
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
23
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24
+
25
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
26
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
28
+
29
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
30
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
32
+
33
+ if opts.rank != 0:
34
+ return float('nan')
35
+
36
+ m = np.square(mu_gen - mu_real).sum()
37
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
38
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
39
+ return float(fid)
40
+
41
+ #----------------------------------------------------------------------------
metrics/inception_score.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Inception Score (IS) from the paper "Improved techniques for training
10
+ GANs". Matches the original implementation by Salimans et al. at
11
+ https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_is(opts, num_gen, num_splits):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
21
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22
+
23
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
24
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
+ capture_all=True, max_items=num_gen).get_all()
26
+
27
+ if opts.rank != 0:
28
+ return float('nan'), float('nan')
29
+
30
+ scores = []
31
+ for i in range(num_splits):
32
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
33
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
34
+ kl = np.mean(np.sum(kl, axis=1))
35
+ scores.append(np.exp(kl))
36
+ return float(np.mean(scores)), float(np.std(scores))
37
+
38
+ #----------------------------------------------------------------------------
metrics/kernel_inception_distance.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10
+ GANs". Matches the original implementation by Binkowski et al. at
11
+ https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
21
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22
+
23
+ real_features = metric_utils.compute_feature_stats_for_dataset(
24
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26
+
27
+ gen_features = metric_utils.compute_feature_stats_for_generator(
28
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30
+
31
+ if opts.rank != 0:
32
+ return float('nan')
33
+
34
+ n = real_features.shape[1]
35
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36
+ t = 0
37
+ for _subset_idx in range(num_subsets):
38
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41
+ b = (x @ y.T / n + 1) ** 3
42
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43
+ kid = t / num_subsets / m
44
+ return float(kid)
45
+
46
+ #----------------------------------------------------------------------------
metrics/metric_main.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Main API for computing and reporting quality metrics."""
10
+
11
+ import os
12
+ import time
13
+ import json
14
+ import torch
15
+ import dnnlib
16
+
17
+ from . import metric_utils
18
+ from . import frechet_inception_distance
19
+ from . import kernel_inception_distance
20
+ from . import precision_recall
21
+ from . import perceptual_path_length
22
+ from . import inception_score
23
+ from . import equivariance
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ _metric_dict = dict() # name => fn
28
+
29
+ def register_metric(fn):
30
+ assert callable(fn)
31
+ _metric_dict[fn.__name__] = fn
32
+ return fn
33
+
34
+ def is_valid_metric(metric):
35
+ return metric in _metric_dict
36
+
37
+ def list_valid_metrics():
38
+ return list(_metric_dict.keys())
39
+
40
+ #----------------------------------------------------------------------------
41
+
42
+ def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
43
+ assert is_valid_metric(metric)
44
+ opts = metric_utils.MetricOptions(**kwargs)
45
+
46
+ # Calculate.
47
+ start_time = time.time()
48
+ results = _metric_dict[metric](opts)
49
+ total_time = time.time() - start_time
50
+
51
+ # Broadcast results.
52
+ for key, value in list(results.items()):
53
+ if opts.num_gpus > 1:
54
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
55
+ torch.distributed.broadcast(tensor=value, src=0)
56
+ value = float(value.cpu())
57
+ results[key] = value
58
+
59
+ # Decorate with metadata.
60
+ return dnnlib.EasyDict(
61
+ results = dnnlib.EasyDict(results),
62
+ metric = metric,
63
+ total_time = total_time,
64
+ total_time_str = dnnlib.util.format_time(total_time),
65
+ num_gpus = opts.num_gpus,
66
+ )
67
+
68
+ #----------------------------------------------------------------------------
69
+
70
+ def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
71
+ metric = result_dict['metric']
72
+ assert is_valid_metric(metric)
73
+ if run_dir is not None and snapshot_pkl is not None:
74
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
75
+
76
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
77
+ print(jsonl_line)
78
+ if run_dir is not None and os.path.isdir(run_dir):
79
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
80
+ f.write(jsonl_line + '\n')
81
+
82
+ #----------------------------------------------------------------------------
83
+ # Recommended metrics.
84
+
85
+ @register_metric
86
+ def fid50k_full(opts):
87
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
88
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
89
+ return dict(fid50k_full=fid)
90
+
91
+ @register_metric
92
+ def kid50k_full(opts):
93
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
94
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
95
+ return dict(kid50k_full=kid)
96
+
97
+ @register_metric
98
+ def pr50k3_full(opts):
99
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
100
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
101
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
102
+
103
+ @register_metric
104
+ def ppl2_wend(opts):
105
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
106
+ return dict(ppl2_wend=ppl)
107
+
108
+ @register_metric
109
+ def eqt50k_int(opts):
110
+ opts.G_kwargs.update(force_fp32=True)
111
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
112
+ return dict(eqt50k_int=psnr)
113
+
114
+ @register_metric
115
+ def eqt50k_frac(opts):
116
+ opts.G_kwargs.update(force_fp32=True)
117
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
118
+ return dict(eqt50k_frac=psnr)
119
+
120
+ @register_metric
121
+ def eqr50k(opts):
122
+ opts.G_kwargs.update(force_fp32=True)
123
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
124
+ return dict(eqr50k=psnr)
125
+
126
+ #----------------------------------------------------------------------------
127
+ # Legacy metrics.
128
+
129
+ @register_metric
130
+ def fid50k(opts):
131
+ opts.dataset_kwargs.update(max_size=None)
132
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
133
+ return dict(fid50k=fid)
134
+
135
+ @register_metric
136
+ def kid50k(opts):
137
+ opts.dataset_kwargs.update(max_size=None)
138
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
139
+ return dict(kid50k=kid)
140
+
141
+ @register_metric
142
+ def pr50k3(opts):
143
+ opts.dataset_kwargs.update(max_size=None)
144
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
145
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
146
+
147
+ @register_metric
148
+ def is50k(opts):
149
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
150
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
151
+ return dict(is50k_mean=mean, is50k_std=std)
152
+
153
+ #----------------------------------------------------------------------------
metrics/metric_utils.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utilities used internally by the quality metrics."""
10
+
11
+ import os
12
+ import time
13
+ import hashlib
14
+ import pickle
15
+ import copy
16
+ import uuid
17
+ import numpy as np
18
+ import torch
19
+ import dnnlib
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ class MetricOptions:
24
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
25
+ assert 0 <= rank < num_gpus
26
+ self.G = G
27
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
28
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
29
+ self.num_gpus = num_gpus
30
+ self.rank = rank
31
+ self.device = device if device is not None else torch.device('cuda', rank)
32
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
33
+ self.cache = cache
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ _feature_detector_cache = dict()
38
+
39
+ def get_feature_detector_name(url):
40
+ return os.path.splitext(url.split('/')[-1])[0]
41
+
42
+ def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
43
+ assert 0 <= rank < num_gpus
44
+ key = (url, device)
45
+ if key not in _feature_detector_cache:
46
+ is_leader = (rank == 0)
47
+ if not is_leader and num_gpus > 1:
48
+ torch.distributed.barrier() # leader goes first
49
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
50
+ _feature_detector_cache[key] = pickle.load(f).to(device)
51
+ if is_leader and num_gpus > 1:
52
+ torch.distributed.barrier() # others follow
53
+ return _feature_detector_cache[key]
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
+ def iterate_random_labels(opts, batch_size):
58
+ if opts.G.c_dim == 0:
59
+ c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
60
+ while True:
61
+ yield c
62
+ else:
63
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
64
+ while True:
65
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
66
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
67
+ yield c
68
+
69
+ #----------------------------------------------------------------------------
70
+
71
+ class FeatureStats:
72
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
73
+ self.capture_all = capture_all
74
+ self.capture_mean_cov = capture_mean_cov
75
+ self.max_items = max_items
76
+ self.num_items = 0
77
+ self.num_features = None
78
+ self.all_features = None
79
+ self.raw_mean = None
80
+ self.raw_cov = None
81
+
82
+ def set_num_features(self, num_features):
83
+ if self.num_features is not None:
84
+ assert num_features == self.num_features
85
+ else:
86
+ self.num_features = num_features
87
+ self.all_features = []
88
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
89
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
90
+
91
+ def is_full(self):
92
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
93
+
94
+ def append(self, x):
95
+ x = np.asarray(x, dtype=np.float32)
96
+ assert x.ndim == 2
97
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
98
+ if self.num_items >= self.max_items:
99
+ return
100
+ x = x[:self.max_items - self.num_items]
101
+
102
+ self.set_num_features(x.shape[1])
103
+ self.num_items += x.shape[0]
104
+ if self.capture_all:
105
+ self.all_features.append(x)
106
+ if self.capture_mean_cov:
107
+ x64 = x.astype(np.float64)
108
+ self.raw_mean += x64.sum(axis=0)
109
+ self.raw_cov += x64.T @ x64
110
+
111
+ def append_torch(self, x, num_gpus=1, rank=0):
112
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
113
+ assert 0 <= rank < num_gpus
114
+ if num_gpus > 1:
115
+ ys = []
116
+ for src in range(num_gpus):
117
+ y = x.clone()
118
+ torch.distributed.broadcast(y, src=src)
119
+ ys.append(y)
120
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
121
+ self.append(x.cpu().numpy())
122
+
123
+ def get_all(self):
124
+ assert self.capture_all
125
+ return np.concatenate(self.all_features, axis=0)
126
+
127
+ def get_all_torch(self):
128
+ return torch.from_numpy(self.get_all())
129
+
130
+ def get_mean_cov(self):
131
+ assert self.capture_mean_cov
132
+ mean = self.raw_mean / self.num_items
133
+ cov = self.raw_cov / self.num_items
134
+ cov = cov - np.outer(mean, mean)
135
+ return mean, cov
136
+
137
+ def save(self, pkl_file):
138
+ with open(pkl_file, 'wb') as f:
139
+ pickle.dump(self.__dict__, f)
140
+
141
+ @staticmethod
142
+ def load(pkl_file):
143
+ with open(pkl_file, 'rb') as f:
144
+ s = dnnlib.EasyDict(pickle.load(f))
145
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
146
+ obj.__dict__.update(s)
147
+ return obj
148
+
149
+ #----------------------------------------------------------------------------
150
+
151
+ class ProgressMonitor:
152
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
153
+ self.tag = tag
154
+ self.num_items = num_items
155
+ self.verbose = verbose
156
+ self.flush_interval = flush_interval
157
+ self.progress_fn = progress_fn
158
+ self.pfn_lo = pfn_lo
159
+ self.pfn_hi = pfn_hi
160
+ self.pfn_total = pfn_total
161
+ self.start_time = time.time()
162
+ self.batch_time = self.start_time
163
+ self.batch_items = 0
164
+ if self.progress_fn is not None:
165
+ self.progress_fn(self.pfn_lo, self.pfn_total)
166
+
167
+ def update(self, cur_items):
168
+ assert (self.num_items is None) or (cur_items <= self.num_items)
169
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
170
+ return
171
+ cur_time = time.time()
172
+ total_time = cur_time - self.start_time
173
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
174
+ if (self.verbose) and (self.tag is not None):
175
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
176
+ self.batch_time = cur_time
177
+ self.batch_items = cur_items
178
+
179
+ if (self.progress_fn is not None) and (self.num_items is not None):
180
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
181
+
182
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
183
+ return ProgressMonitor(
184
+ tag = tag,
185
+ num_items = num_items,
186
+ flush_interval = flush_interval,
187
+ verbose = self.verbose,
188
+ progress_fn = self.progress_fn,
189
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
190
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
191
+ pfn_total = self.pfn_total,
192
+ )
193
+
194
+ #----------------------------------------------------------------------------
195
+
196
+ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
197
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
198
+ if data_loader_kwargs is None:
199
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
200
+
201
+ # Try to lookup from cache.
202
+ cache_file = None
203
+ if opts.cache:
204
+ # Choose cache file name.
205
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
206
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
207
+ cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
208
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
209
+
210
+ # Check if the file exists (all processes must agree).
211
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
212
+ if opts.num_gpus > 1:
213
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
214
+ torch.distributed.broadcast(tensor=flag, src=0)
215
+ flag = (float(flag.cpu()) != 0)
216
+
217
+ # Load.
218
+ if flag:
219
+ return FeatureStats.load(cache_file)
220
+
221
+ # Initialize.
222
+ num_items = len(dataset)
223
+ if max_items is not None:
224
+ num_items = min(num_items, max_items)
225
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
226
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
227
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
228
+
229
+ # Main loop.
230
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
231
+ for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
232
+ if images.shape[1] == 1:
233
+ images = images.repeat([1, 3, 1, 1])
234
+ features = detector(images.to(opts.device), **detector_kwargs)
235
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
236
+ progress.update(stats.num_items)
237
+
238
+ # Save to cache.
239
+ if cache_file is not None and opts.rank == 0:
240
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
241
+ temp_file = cache_file + '.' + uuid.uuid4().hex
242
+ stats.save(temp_file)
243
+ os.replace(temp_file, cache_file) # atomic
244
+ return stats
245
+
246
+ #----------------------------------------------------------------------------
247
+
248
+ def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
249
+ if batch_gen is None:
250
+ batch_gen = min(batch_size, 4)
251
+ assert batch_size % batch_gen == 0
252
+
253
+ # Setup generator and labels.
254
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
255
+ c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)
256
+
257
+ # Initialize.
258
+ stats = FeatureStats(**stats_kwargs)
259
+ assert stats.max_items is not None
260
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
261
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
262
+
263
+ # Main loop.
264
+ while not stats.is_full():
265
+ images = []
266
+ for _i in range(batch_size // batch_gen):
267
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
268
+ img = G(z=z, c=next(c_iter), **opts.G_kwargs)
269
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
270
+ images.append(img)
271
+ images = torch.cat(images)
272
+ if images.shape[1] == 1:
273
+ images = images.repeat([1, 3, 1, 1])
274
+ features = detector(images, **detector_kwargs)
275
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
276
+ progress.update(stats.num_items)
277
+ return stats
278
+
279
+ #----------------------------------------------------------------------------
metrics/perceptual_path_length.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10
+ Architecture for Generative Adversarial Networks". Matches the original
11
+ implementation by Karras et al. at
12
+ https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13
+
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ from . import metric_utils
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ # Spherical interpolation of a batch of vectors.
22
+ def slerp(a, b, t):
23
+ a = a / a.norm(dim=-1, keepdim=True)
24
+ b = b / b.norm(dim=-1, keepdim=True)
25
+ d = (a * b).sum(dim=-1, keepdim=True)
26
+ p = t * torch.acos(d)
27
+ c = b - d * a
28
+ c = c / c.norm(dim=-1, keepdim=True)
29
+ d = a * torch.cos(p) + c * torch.sin(p)
30
+ d = d / d.norm(dim=-1, keepdim=True)
31
+ return d
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ class PPLSampler(torch.nn.Module):
36
+ def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
37
+ assert space in ['z', 'w']
38
+ assert sampling in ['full', 'end']
39
+ super().__init__()
40
+ self.G = copy.deepcopy(G)
41
+ self.G_kwargs = G_kwargs
42
+ self.epsilon = epsilon
43
+ self.space = space
44
+ self.sampling = sampling
45
+ self.crop = crop
46
+ self.vgg16 = copy.deepcopy(vgg16)
47
+
48
+ def forward(self, c):
49
+ # Generate random latents and interpolation t-values.
50
+ t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
51
+ z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
52
+
53
+ # Interpolate in W or Z.
54
+ if self.space == 'w':
55
+ w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
56
+ wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
57
+ wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
58
+ else: # space == 'z'
59
+ zt0 = slerp(z0, z1, t.unsqueeze(1))
60
+ zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
61
+ wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
62
+
63
+ # Randomize noise buffers.
64
+ for name, buf in self.G.named_buffers():
65
+ if name.endswith('.noise_const'):
66
+ buf.copy_(torch.randn_like(buf))
67
+
68
+ # Generate images.
69
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
70
+
71
+ # Center crop.
72
+ if self.crop:
73
+ assert img.shape[2] == img.shape[3]
74
+ c = img.shape[2] // 8
75
+ img = img[:, :, c*3 : c*7, c*2 : c*6]
76
+
77
+ # Downsample to 256x256.
78
+ factor = self.G.img_resolution // 256
79
+ if factor > 1:
80
+ img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
81
+
82
+ # Scale dynamic range from [-1,1] to [0,255].
83
+ img = (img + 1) * (255 / 2)
84
+ if self.G.img_channels == 1:
85
+ img = img.repeat([1, 3, 1, 1])
86
+
87
+ # Evaluate differential LPIPS.
88
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
89
+ dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
90
+ return dist
91
+
92
+ #----------------------------------------------------------------------------
93
+
94
+ def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
95
+ vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
96
+ vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
97
+
98
+ # Setup sampler and labels.
99
+ sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
100
+ sampler.eval().requires_grad_(False).to(opts.device)
101
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
102
+
103
+ # Sampling loop.
104
+ dist = []
105
+ progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
106
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
107
+ progress.update(batch_start)
108
+ x = sampler(next(c_iter))
109
+ for src in range(opts.num_gpus):
110
+ y = x.clone()
111
+ if opts.num_gpus > 1:
112
+ torch.distributed.broadcast(y, src=src)
113
+ dist.append(y)
114
+ progress.update(num_samples)
115
+
116
+ # Compute PPL.
117
+ if opts.rank != 0:
118
+ return float('nan')
119
+ dist = torch.cat(dist)[:num_samples].cpu().numpy()
120
+ lo = np.percentile(dist, 1, interpolation='lower')
121
+ hi = np.percentile(dist, 99, interpolation='higher')
122
+ ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
123
+ return float(ppl)
124
+
125
+ #----------------------------------------------------------------------------
metrics/precision_recall.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Precision/Recall (PR) from the paper "Improved Precision and Recall
10
+ Metric for Assessing Generative Models". Matches the original implementation
11
+ by Kynkaanniemi et al. at
12
+ https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13
+
14
+ import torch
15
+ from . import metric_utils
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
20
+ assert 0 <= rank < num_gpus
21
+ num_cols = col_features.shape[0]
22
+ num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
23
+ col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
24
+ dist_batches = []
25
+ for col_batch in col_batches[rank :: num_gpus]:
26
+ dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
27
+ for src in range(num_gpus):
28
+ dist_broadcast = dist_batch.clone()
29
+ if num_gpus > 1:
30
+ torch.distributed.broadcast(dist_broadcast, src=src)
31
+ dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
32
+ return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
37
+ detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
38
+ detector_kwargs = dict(return_features=True)
39
+
40
+ real_features = metric_utils.compute_feature_stats_for_dataset(
41
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
42
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
43
+
44
+ gen_features = metric_utils.compute_feature_stats_for_generator(
45
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
46
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
47
+
48
+ results = dict()
49
+ for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
50
+ kth = []
51
+ for manifold_batch in manifold.split(row_batch_size):
52
+ dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
53
+ kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
54
+ kth = torch.cat(kth) if opts.rank == 0 else None
55
+ pred = []
56
+ for probes_batch in probes.split(row_batch_size):
57
+ dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
58
+ pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
59
+ results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
60
+ return results['precision'], results['recall']
61
+
62
+ #----------------------------------------------------------------------------
torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
torch_utils/custom_ops.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import glob
10
+ import hashlib
11
+ import importlib
12
+ import os
13
+ import re
14
+ import shutil
15
+ import uuid
16
+
17
+ import torch
18
+ import torch.utils.cpp_extension
19
+ from torch.utils.file_baton import FileBaton
20
+
21
+ #----------------------------------------------------------------------------
22
+ # Global options.
23
+
24
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25
+
26
+ #----------------------------------------------------------------------------
27
+ # Internal helper funcs.
28
+
29
+ def _find_compiler_bindir():
30
+ patterns = [
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
35
+ ]
36
+ for pattern in patterns:
37
+ matches = sorted(glob.glob(pattern))
38
+ if len(matches):
39
+ return matches[-1]
40
+ return None
41
+
42
+ #----------------------------------------------------------------------------
43
+
44
+ def _get_mangled_gpu_name():
45
+ name = torch.cuda.get_device_name().lower()
46
+ out = []
47
+ for c in name:
48
+ if re.match('[a-z0-9_-]+', c):
49
+ out.append(c)
50
+ else:
51
+ out.append('-')
52
+ return ''.join(out)
53
+
54
+ #----------------------------------------------------------------------------
55
+ # Main entry point for compiling and loading C++/CUDA plugins.
56
+
57
+ _cached_plugins = dict()
58
+
59
+ def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60
+ assert verbosity in ['none', 'brief', 'full']
61
+ if headers is None:
62
+ headers = []
63
+ if source_dir is not None:
64
+ sources = [os.path.join(source_dir, fname) for fname in sources]
65
+ headers = [os.path.join(source_dir, fname) for fname in headers]
66
+
67
+ # Already cached?
68
+ if module_name in _cached_plugins:
69
+ return _cached_plugins[module_name]
70
+
71
+ # Print status.
72
+ if verbosity == 'full':
73
+ print(f'Setting up PyTorch plugin "{module_name}"...')
74
+ elif verbosity == 'brief':
75
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76
+ verbose_build = (verbosity == 'full')
77
+
78
+ # Compile and load.
79
+ try: # pylint: disable=too-many-nested-blocks
80
+ # Make sure we can find the necessary compiler binaries.
81
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82
+ compiler_bindir = _find_compiler_bindir()
83
+ if compiler_bindir is None:
84
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85
+ os.environ['PATH'] += ';' + compiler_bindir
86
+
87
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88
+ # break the build or unnecessarily restrict what's available to nvcc.
89
+ # Unset it to let nvcc decide based on what's available on the
90
+ # machine.
91
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92
+
93
+ # Incremental build md5sum trickery. Copies all the input source files
94
+ # into a cached build directory under a combined md5 digest of the input
95
+ # source files. Copying is done only if the combined digest has changed.
96
+ # This keeps input file timestamps and filenames the same as in previous
97
+ # extension builds, allowing for fast incremental rebuilds.
98
+ #
99
+ # This optimization is done only in case all the source files reside in
100
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101
+ # environment variable is set (we take this as a signal that the user
102
+ # actually cares about this.)
103
+ #
104
+ # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
105
+ # around the *.cu dependency bug in ninja config.
106
+ #
107
+ all_source_files = sorted(sources + headers)
108
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110
+
111
+ # Compute combined hash digest for all source files.
112
+ hash_md5 = hashlib.md5()
113
+ for src in all_source_files:
114
+ with open(src, 'rb') as f:
115
+ hash_md5.update(f.read())
116
+
117
+ # Select cached build directory name.
118
+ source_digest = hash_md5.hexdigest()
119
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121
+
122
+ if not os.path.isdir(cached_build_dir):
123
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124
+ os.makedirs(tmpdir)
125
+ for src in all_source_files:
126
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127
+ try:
128
+ os.replace(tmpdir, cached_build_dir) # atomic
129
+ except OSError:
130
+ # source directory already exists, delete tmpdir and its contents.
131
+ shutil.rmtree(tmpdir)
132
+ if not os.path.isdir(cached_build_dir): raise
133
+
134
+ # Compile.
135
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
138
+ else:
139
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140
+
141
+ # Load.
142
+ module = importlib.import_module(module_name)
143
+
144
+ except:
145
+ if verbosity == 'brief':
146
+ print('Failed!')
147
+ raise
148
+
149
+ # Print status and add to cache dict.
150
+ if verbosity == 'full':
151
+ print(f'Done setting up PyTorch plugin "{module_name}".')
152
+ elif verbosity == 'brief':
153
+ print('Done.')
154
+ _cached_plugins[module_name] = module
155
+ return module
156
+
157
+ #----------------------------------------------------------------------------
torch_utils/misc.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import re
10
+ import contextlib
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+ import dnnlib
15
+
16
+ #----------------------------------------------------------------------------
17
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18
+ # same constant is used multiple times.
19
+
20
+ _constant_cache = dict()
21
+
22
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23
+ value = np.asarray(value)
24
+ if shape is not None:
25
+ shape = tuple(shape)
26
+ if dtype is None:
27
+ dtype = torch.get_default_dtype()
28
+ if device is None:
29
+ device = torch.device('cpu')
30
+ if memory_format is None:
31
+ memory_format = torch.contiguous_format
32
+
33
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34
+ tensor = _constant_cache.get(key, None)
35
+ if tensor is None:
36
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37
+ if shape is not None:
38
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39
+ tensor = tensor.contiguous(memory_format=memory_format)
40
+ _constant_cache[key] = tensor
41
+ return tensor
42
+
43
+ #----------------------------------------------------------------------------
44
+ # Replace NaN/Inf with specified numerical values.
45
+
46
+ try:
47
+ nan_to_num = torch.nan_to_num # 1.8.0a0
48
+ except AttributeError:
49
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50
+ assert isinstance(input, torch.Tensor)
51
+ if posinf is None:
52
+ posinf = torch.finfo(input.dtype).max
53
+ if neginf is None:
54
+ neginf = torch.finfo(input.dtype).min
55
+ assert nan == 0
56
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57
+
58
+ #----------------------------------------------------------------------------
59
+ # Symbolic assert.
60
+
61
+ try:
62
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63
+ except AttributeError:
64
+ symbolic_assert = torch.Assert # 1.7.0
65
+
66
+ #----------------------------------------------------------------------------
67
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
68
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
69
+
70
+ @contextlib.contextmanager
71
+ def suppress_tracer_warnings():
72
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
73
+ warnings.filters.insert(0, flt)
74
+ yield
75
+ warnings.filters.remove(flt)
76
+
77
+ #----------------------------------------------------------------------------
78
+ # Assert that the shape of a tensor matches the given list of integers.
79
+ # None indicates that the size of a dimension is allowed to vary.
80
+ # Performs symbolic assertion when used in torch.jit.trace().
81
+
82
+ def assert_shape(tensor, ref_shape):
83
+ if tensor.ndim != len(ref_shape):
84
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
85
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
86
+ if ref_size is None:
87
+ pass
88
+ elif isinstance(ref_size, torch.Tensor):
89
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
90
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
91
+ elif isinstance(size, torch.Tensor):
92
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
93
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
94
+ elif size != ref_size:
95
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
96
+
97
+ #----------------------------------------------------------------------------
98
+ # Function decorator that calls torch.autograd.profiler.record_function().
99
+
100
+ def profiled_function(fn):
101
+ def decorator(*args, **kwargs):
102
+ with torch.autograd.profiler.record_function(fn.__name__):
103
+ return fn(*args, **kwargs)
104
+ decorator.__name__ = fn.__name__
105
+ return decorator
106
+
107
+ #----------------------------------------------------------------------------
108
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
109
+ # indefinitely, shuffling items as it goes.
110
+
111
+ class InfiniteSampler(torch.utils.data.Sampler):
112
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
113
+ assert len(dataset) > 0
114
+ assert num_replicas > 0
115
+ assert 0 <= rank < num_replicas
116
+ assert 0 <= window_size <= 1
117
+ super().__init__(dataset)
118
+ self.dataset = dataset
119
+ self.rank = rank
120
+ self.num_replicas = num_replicas
121
+ self.shuffle = shuffle
122
+ self.seed = seed
123
+ self.window_size = window_size
124
+
125
+ def __iter__(self):
126
+ order = np.arange(len(self.dataset))
127
+ rnd = None
128
+ window = 0
129
+ if self.shuffle:
130
+ rnd = np.random.RandomState(self.seed)
131
+ rnd.shuffle(order)
132
+ window = int(np.rint(order.size * self.window_size))
133
+
134
+ idx = 0
135
+ while True:
136
+ i = idx % order.size
137
+ if idx % self.num_replicas == self.rank:
138
+ yield order[i]
139
+ if window >= 2:
140
+ j = (i - rnd.randint(window)) % order.size
141
+ order[i], order[j] = order[j], order[i]
142
+ idx += 1
143
+
144
+ #----------------------------------------------------------------------------
145
+ # Utilities for operating with torch.nn.Module parameters and buffers.
146
+
147
+ def params_and_buffers(module):
148
+ assert isinstance(module, torch.nn.Module)
149
+ return list(module.parameters()) + list(module.buffers())
150
+
151
+ def named_params_and_buffers(module):
152
+ assert isinstance(module, torch.nn.Module)
153
+ return list(module.named_parameters()) + list(module.named_buffers())
154
+
155
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
156
+ assert isinstance(src_module, torch.nn.Module)
157
+ assert isinstance(dst_module, torch.nn.Module)
158
+ src_tensors = dict(named_params_and_buffers(src_module))
159
+ for name, tensor in named_params_and_buffers(dst_module):
160
+ assert (name in src_tensors) or (not require_all)
161
+ if name in src_tensors:
162
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
163
+
164
+ #----------------------------------------------------------------------------
165
+ # Context manager for easily enabling/disabling DistributedDataParallel
166
+ # synchronization.
167
+
168
+ @contextlib.contextmanager
169
+ def ddp_sync(module, sync):
170
+ assert isinstance(module, torch.nn.Module)
171
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172
+ yield
173
+ else:
174
+ with module.no_sync():
175
+ yield
176
+
177
+ #----------------------------------------------------------------------------
178
+ # Check DistributedDataParallel consistency across processes.
179
+
180
+ def check_ddp_consistency(module, ignore_regex=None):
181
+ assert isinstance(module, torch.nn.Module)
182
+ for name, tensor in named_params_and_buffers(module):
183
+ fullname = type(module).__name__ + '.' + name
184
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185
+ continue
186
+ tensor = tensor.detach()
187
+ if tensor.is_floating_point():
188
+ tensor = nan_to_num(tensor)
189
+ other = tensor.clone()
190
+ torch.distributed.broadcast(tensor=other, src=0)
191
+ assert (tensor == other).all(), fullname
192
+
193
+ #----------------------------------------------------------------------------
194
+ # Print summary table of module hierarchy.
195
+
196
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
197
+ assert isinstance(module, torch.nn.Module)
198
+ assert not isinstance(module, torch.jit.ScriptModule)
199
+ assert isinstance(inputs, (tuple, list))
200
+
201
+ # Register hooks.
202
+ entries = []
203
+ nesting = [0]
204
+ def pre_hook(_mod, _inputs):
205
+ nesting[0] += 1
206
+ def post_hook(mod, _inputs, outputs):
207
+ nesting[0] -= 1
208
+ if nesting[0] <= max_nesting:
209
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
210
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
211
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
212
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
213
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
214
+
215
+ # Run module.
216
+ outputs = module(*inputs)
217
+ for hook in hooks:
218
+ hook.remove()
219
+
220
+ # Identify unique outputs, parameters, and buffers.
221
+ tensors_seen = set()
222
+ for e in entries:
223
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
224
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
225
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
226
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
227
+
228
+ # Filter out redundant entries.
229
+ if skip_redundant:
230
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
231
+
232
+ # Construct table.
233
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
234
+ rows += [['---'] * len(rows[0])]
235
+ param_total = 0
236
+ buffer_total = 0
237
+ submodule_names = {mod: name for name, mod in module.named_modules()}
238
+ for e in entries:
239
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
240
+ param_size = sum(t.numel() for t in e.unique_params)
241
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
242
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
243
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
244
+ rows += [[
245
+ name + (':0' if len(e.outputs) >= 2 else ''),
246
+ str(param_size) if param_size else '-',
247
+ str(buffer_size) if buffer_size else '-',
248
+ (output_shapes + ['-'])[0],
249
+ (output_dtypes + ['-'])[0],
250
+ ]]
251
+ for idx in range(1, len(e.outputs)):
252
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
253
+ param_total += param_size
254
+ buffer_total += buffer_size
255
+ rows += [['---'] * len(rows[0])]
256
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
257
+
258
+ # Print table.
259
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
260
+ print()
261
+ for row in rows:
262
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
263
+ print()
264
+ return outputs
265
+
266
+ #----------------------------------------------------------------------------
torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
+ template <class T, int A>
24
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
25
+ {
26
+ typedef typename InternalType<T>::scalar_t scalar_t;
27
+ int G = p.grad;
28
+ scalar_t alpha = (scalar_t)p.alpha;
29
+ scalar_t gain = (scalar_t)p.gain;
30
+ scalar_t clamp = (scalar_t)p.clamp;
31
+ scalar_t one = (scalar_t)1;
32
+ scalar_t two = (scalar_t)2;
33
+ scalar_t expRange = (scalar_t)80;
34
+ scalar_t halfExpRange = (scalar_t)40;
35
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
+
38
+ // Loop over elements.
39
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
+ {
42
+ // Load.
43
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
49
+ scalar_t y = 0;
50
+
51
+ // Apply bias.
52
+ ((G == 0) ? x : xref) += b;
53
+
54
+ // linear
55
+ if (A == 1)
56
+ {
57
+ if (G == 0) y = x;
58
+ if (G == 1) y = x;
59
+ }
60
+
61
+ // relu
62
+ if (A == 2)
63
+ {
64
+ if (G == 0) y = (x > 0) ? x : 0;
65
+ if (G == 1) y = (yy > 0) ? x : 0;
66
+ }
67
+
68
+ // lrelu
69
+ if (A == 3)
70
+ {
71
+ if (G == 0) y = (x > 0) ? x : x * alpha;
72
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
73
+ }
74
+
75
+ // tanh
76
+ if (A == 4)
77
+ {
78
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
+ if (G == 1) y = x * (one - yy * yy);
80
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
+ }
82
+
83
+ // sigmoid
84
+ if (A == 5)
85
+ {
86
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
+ if (G == 1) y = x * yy * (one - yy);
88
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
+ }
90
+
91
+ // elu
92
+ if (A == 6)
93
+ {
94
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
+ }
98
+
99
+ // selu
100
+ if (A == 7)
101
+ {
102
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
+ }
106
+
107
+ // softplus
108
+ if (A == 8)
109
+ {
110
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
+ if (G == 1) y = x * (one - exp(-yy));
112
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
+ }
114
+
115
+ // swish
116
+ if (A == 9)
117
+ {
118
+ if (G == 0)
119
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
+ else
121
+ {
122
+ scalar_t c = exp(xref);
123
+ scalar_t d = c + one;
124
+ if (G == 1)
125
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
+ else
127
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
+ }
130
+ }
131
+
132
+ // Apply gain.
133
+ y *= gain * dy;
134
+
135
+ // Clamp.
136
+ if (clamp >= 0)
137
+ {
138
+ if (G == 0)
139
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
+ else
141
+ y = (yref > -clamp & yref < clamp) ? y : 0;
142
+ }
143
+
144
+ // Store.
145
+ ((T*)p.y)[xi] = (T)y;
146
+ }
147
+ }
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
+ {
154
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
+ return NULL;
164
+ }
165
+
166
+ //------------------------------------------------------------------------
167
+ // Template specializations.
168
+
169
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
+ struct bias_act_kernel_params
13
+ {
14
+ const void* x; // [sizeX]
15
+ const void* b; // [sizeB] or NULL
16
+ const void* xref; // [sizeX] or NULL
17
+ const void* yref; // [sizeX] or NULL
18
+ const void* dy; // [sizeX] or NULL
19
+ void* y; // [sizeX]
20
+
21
+ int grad;
22
+ int act;
23
+ float alpha;
24
+ float gain;
25
+ float clamp;
26
+
27
+ int sizeX;
28
+ int sizeB;
29
+ int stepB;
30
+ int loopX;
31
+ };
32
+
33
+ //------------------------------------------------------------------------
34
+ // CUDA kernel selection.
35
+
36
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
+
38
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+ import dnnlib
15
+
16
+ from .. import custom_ops
17
+ from .. import misc
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ activation_funcs = {
22
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
23
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
24
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
25
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
26
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
27
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
28
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
29
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
30
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
31
+ }
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ _plugin = None
36
+ _null_tensor = torch.empty([0])
37
+
38
+ def _init():
39
+ global _plugin
40
+ if _plugin is None:
41
+ _plugin = custom_ops.get_plugin(
42
+ module_name='bias_act_plugin',
43
+ sources=['bias_act.cpp', 'bias_act.cu'],
44
+ headers=['bias_act.h'],
45
+ source_dir=os.path.dirname(__file__),
46
+ extra_cuda_cflags=['--use_fast_math'],
47
+ )
48
+ return True
49
+
50
+ #----------------------------------------------------------------------------
51
+
52
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
53
+ r"""Fused bias and activation function.
54
+
55
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
56
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
57
+ the fused op is considerably more efficient than performing the same calculation
58
+ using standard PyTorch ops. It supports first and second order gradients,
59
+ but not third order gradients.
60
+
61
+ Args:
62
+ x: Input activation tensor. Can be of any shape.
63
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
64
+ as `x`. The shape must be known, and it must match the dimension of `x`
65
+ corresponding to `dim`.
66
+ dim: The dimension in `x` corresponding to the elements of `b`.
67
+ The value of `dim` is ignored if `b` is not specified.
68
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
69
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
70
+ See `activation_funcs` for a full list. `None` is not allowed.
71
+ alpha: Shape parameter for the activation function, or `None` to use the default.
72
+ gain: Scaling factor for the output tensor, or `None` to use default.
73
+ See `activation_funcs` for the default scaling of each activation function.
74
+ If unsure, consider specifying 1.
75
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
76
+ the clamping (default).
77
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
78
+
79
+ Returns:
80
+ Tensor of the same shape and datatype as `x`.
81
+ """
82
+ assert isinstance(x, torch.Tensor)
83
+ assert impl in ['ref', 'cuda']
84
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
85
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
86
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
+ @misc.profiled_function
91
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
92
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
93
+ """
94
+ assert isinstance(x, torch.Tensor)
95
+ assert clamp is None or clamp >= 0
96
+ spec = activation_funcs[act]
97
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
98
+ gain = float(gain if gain is not None else spec.def_gain)
99
+ clamp = float(clamp if clamp is not None else -1)
100
+
101
+ # Add bias.
102
+ if b is not None:
103
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
104
+ assert 0 <= dim < x.ndim
105
+ assert b.shape[0] == x.shape[dim]
106
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
107
+
108
+ # Evaluate activation function.
109
+ alpha = float(alpha)
110
+ x = spec.func(x, alpha=alpha)
111
+
112
+ # Scale by gain.
113
+ gain = float(gain)
114
+ if gain != 1:
115
+ x = x * gain
116
+
117
+ # Clamp.
118
+ if clamp >= 0:
119
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
120
+ return x
121
+
122
+ #----------------------------------------------------------------------------
123
+
124
+ _bias_act_cuda_cache = dict()
125
+
126
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
127
+ """Fast CUDA implementation of `bias_act()` using custom ops.
128
+ """
129
+ # Parse arguments.
130
+ assert clamp is None or clamp >= 0
131
+ spec = activation_funcs[act]
132
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
133
+ gain = float(gain if gain is not None else spec.def_gain)
134
+ clamp = float(clamp if clamp is not None else -1)
135
+
136
+ # Lookup from cache.
137
+ key = (dim, act, alpha, gain, clamp)
138
+ if key in _bias_act_cuda_cache:
139
+ return _bias_act_cuda_cache[key]
140
+
141
+ # Forward op.
142
+ class BiasActCuda(torch.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
145
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
146
+ x = x.contiguous(memory_format=ctx.memory_format)
147
+ b = b.contiguous() if b is not None else _null_tensor
148
+ y = x
149
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
150
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
151
+ ctx.save_for_backward(
152
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
153
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
154
+ y if 'y' in spec.ref else _null_tensor)
155
+ return y
156
+
157
+ @staticmethod
158
+ def backward(ctx, dy): # pylint: disable=arguments-differ
159
+ dy = dy.contiguous(memory_format=ctx.memory_format)
160
+ x, b, y = ctx.saved_tensors
161
+ dx = None
162
+ db = None
163
+
164
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
165
+ dx = dy
166
+ if act != 'linear' or gain != 1 or clamp >= 0:
167
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
168
+
169
+ if ctx.needs_input_grad[1]:
170
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
171
+
172
+ return dx, db
173
+
174
+ # Backward op.
175
+ class BiasActCudaGrad(torch.autograd.Function):
176
+ @staticmethod
177
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
178
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
179
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
180
+ ctx.save_for_backward(
181
+ dy if spec.has_2nd_grad else _null_tensor,
182
+ x, b, y)
183
+ return dx
184
+
185
+ @staticmethod
186
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
187
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
188
+ dy, x, b, y = ctx.saved_tensors
189
+ d_dy = None
190
+ d_x = None
191
+ d_b = None
192
+ d_y = None
193
+
194
+ if ctx.needs_input_grad[0]:
195
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
196
+
197
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
198
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
199
+
200
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
201
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
202
+
203
+ return d_dy, d_x, d_b, d_y
204
+
205
+ # Add to cache.
206
+ _bias_act_cuda_cache[key] = BiasActCuda
207
+ return BiasActCuda
208
+
209
+ #----------------------------------------------------------------------------
torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
10
+ arbitrarily high order gradients with zero performance penalty."""
11
+
12
+ import contextlib
13
+ import torch
14
+
15
+ # pylint: disable=redefined-builtin
16
+ # pylint: disable=arguments-differ
17
+ # pylint: disable=protected-access
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ enabled = False # Enable the custom op by setting this to true.
22
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
23
+
24
+ @contextlib.contextmanager
25
+ def no_weight_gradients(disable=True):
26
+ global weight_gradients_disabled
27
+ old = weight_gradients_disabled
28
+ if disable:
29
+ weight_gradients_disabled = True
30
+ yield
31
+ weight_gradients_disabled = old
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36
+ if _should_use_custom_op(input):
37
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39
+
40
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41
+ if _should_use_custom_op(input):
42
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ def _should_use_custom_op(input):
48
+ assert isinstance(input, torch.Tensor)
49
+ if (not enabled) or (not torch.backends.cudnn.enabled):
50
+ return False
51
+ if input.device.type != 'cuda':
52
+ return False
53
+ return True
54
+
55
+ def _tuple_of_ints(xs, ndim):
56
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
57
+ assert len(xs) == ndim
58
+ assert all(isinstance(x, int) for x in xs)
59
+ return xs
60
+
61
+ #----------------------------------------------------------------------------
62
+
63
+ _conv2d_gradfix_cache = dict()
64
+ _null_tensor = torch.empty([0])
65
+
66
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
67
+ # Parse arguments.
68
+ ndim = 2
69
+ weight_shape = tuple(weight_shape)
70
+ stride = _tuple_of_ints(stride, ndim)
71
+ padding = _tuple_of_ints(padding, ndim)
72
+ output_padding = _tuple_of_ints(output_padding, ndim)
73
+ dilation = _tuple_of_ints(dilation, ndim)
74
+
75
+ # Lookup from cache.
76
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
77
+ if key in _conv2d_gradfix_cache:
78
+ return _conv2d_gradfix_cache[key]
79
+
80
+ # Validate arguments.
81
+ assert groups >= 1
82
+ assert len(weight_shape) == ndim + 2
83
+ assert all(stride[i] >= 1 for i in range(ndim))
84
+ assert all(padding[i] >= 0 for i in range(ndim))
85
+ assert all(dilation[i] >= 0 for i in range(ndim))
86
+ if not transpose:
87
+ assert all(output_padding[i] == 0 for i in range(ndim))
88
+ else: # transpose
89
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
90
+
91
+ # Helpers.
92
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
93
+ def calc_output_padding(input_shape, output_shape):
94
+ if transpose:
95
+ return [0, 0]
96
+ return [
97
+ input_shape[i + 2]
98
+ - (output_shape[i + 2] - 1) * stride[i]
99
+ - (1 - 2 * padding[i])
100
+ - dilation[i] * (weight_shape[i + 2] - 1)
101
+ for i in range(ndim)
102
+ ]
103
+
104
+ # Forward & backward.
105
+ class Conv2d(torch.autograd.Function):
106
+ @staticmethod
107
+ def forward(ctx, input, weight, bias):
108
+ assert weight.shape == weight_shape
109
+ ctx.save_for_backward(
110
+ input if weight.requires_grad else _null_tensor,
111
+ weight if input.requires_grad else _null_tensor,
112
+ )
113
+ ctx.input_shape = input.shape
114
+
115
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
116
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
117
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
118
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
119
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
120
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
121
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
122
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
123
+
124
+ # General case => cuDNN.
125
+ if transpose:
126
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
127
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
128
+
129
+ @staticmethod
130
+ def backward(ctx, grad_output):
131
+ input, weight = ctx.saved_tensors
132
+ input_shape = ctx.input_shape
133
+ grad_input = None
134
+ grad_weight = None
135
+ grad_bias = None
136
+
137
+ if ctx.needs_input_grad[0]:
138
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
139
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
140
+ grad_input = op.apply(grad_output, weight, None)
141
+ assert grad_input.shape == input_shape
142
+
143
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
144
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
145
+ assert grad_weight.shape == weight_shape
146
+
147
+ if ctx.needs_input_grad[2]:
148
+ grad_bias = grad_output.sum([0, 2, 3])
149
+
150
+ return grad_input, grad_weight, grad_bias
151
+
152
+ # Gradient with respect to the weights.
153
+ class Conv2dGradWeight(torch.autograd.Function):
154
+ @staticmethod
155
+ def forward(ctx, grad_output, input):
156
+ ctx.save_for_backward(
157
+ grad_output if input.requires_grad else _null_tensor,
158
+ input if grad_output.requires_grad else _null_tensor,
159
+ )
160
+ ctx.grad_output_shape = grad_output.shape
161
+ ctx.input_shape = input.shape
162
+
163
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
164
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
165
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
166
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
167
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
168
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
169
+
170
+ # General case => cuDNN.
171
+ name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
172
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
173
+ return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
174
+
175
+ @staticmethod
176
+ def backward(ctx, grad2_grad_weight):
177
+ grad_output, input = ctx.saved_tensors
178
+ grad_output_shape = ctx.grad_output_shape
179
+ input_shape = ctx.input_shape
180
+ grad2_grad_output = None
181
+ grad2_input = None
182
+
183
+ if ctx.needs_input_grad[0]:
184
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
185
+ assert grad2_grad_output.shape == grad_output_shape
186
+
187
+ if ctx.needs_input_grad[1]:
188
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
189
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
190
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
191
+ assert grad2_input.shape == input_shape
192
+
193
+ return grad2_grad_output, grad2_input
194
+
195
+ _conv2d_gradfix_cache[key] = Conv2d
196
+ return Conv2d
197
+
198
+ #----------------------------------------------------------------------------
torch_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """2D convolution with optional up/downsampling."""
10
+
11
+ import torch
12
+
13
+ from .. import misc
14
+ from . import conv2d_gradfix
15
+ from . import upfirdn2d
16
+ from .upfirdn2d import _parse_padding
17
+ from .upfirdn2d import _get_filter_size
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def _get_weight_shape(w):
22
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
+ shape = [int(sz) for sz in w.shape]
24
+ misc.assert_shape(w, shape)
25
+ return shape
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
+ """
32
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
33
+
34
+ # Flip weight if requested.
35
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
+ if not flip_weight and (kw > 1 or kh > 1):
37
+ w = w.flip([2, 3])
38
+
39
+ # Execute using conv2d_gradfix.
40
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
41
+ return op(x, w, stride=stride, padding=padding, groups=groups)
42
+
43
+ #----------------------------------------------------------------------------
44
+
45
+ @misc.profiled_function
46
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
47
+ r"""2D convolution with optional up/downsampling.
48
+
49
+ Padding is performed only once at the beginning, not between the operations.
50
+
51
+ Args:
52
+ x: Input tensor of shape
53
+ `[batch_size, in_channels, in_height, in_width]`.
54
+ w: Weight tensor of shape
55
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
56
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
57
+ calling upfirdn2d.setup_filter(). None = identity (default).
58
+ up: Integer upsampling factor (default: 1).
59
+ down: Integer downsampling factor (default: 1).
60
+ padding: Padding with respect to the upsampled image. Can be a single number
61
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
62
+ (default: 0).
63
+ groups: Split input channels into N groups (default: 1).
64
+ flip_weight: False = convolution, True = correlation (default: True).
65
+ flip_filter: False = convolution, True = correlation (default: False).
66
+
67
+ Returns:
68
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
69
+ """
70
+ # Validate arguments.
71
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
72
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
73
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
74
+ assert isinstance(up, int) and (up >= 1)
75
+ assert isinstance(down, int) and (down >= 1)
76
+ assert isinstance(groups, int) and (groups >= 1)
77
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
78
+ fw, fh = _get_filter_size(f)
79
+ px0, px1, py0, py1 = _parse_padding(padding)
80
+
81
+ # Adjust padding to account for up/downsampling.
82
+ if up > 1:
83
+ px0 += (fw + up - 1) // 2
84
+ px1 += (fw - up) // 2
85
+ py0 += (fh + up - 1) // 2
86
+ py1 += (fh - up) // 2
87
+ if down > 1:
88
+ px0 += (fw - down + 1) // 2
89
+ px1 += (fw - down) // 2
90
+ py0 += (fh - down + 1) // 2
91
+ py1 += (fh - down) // 2
92
+
93
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
94
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
95
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97
+ return x
98
+
99
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
101
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103
+ return x
104
+
105
+ # Fast path: downsampling only => use strided convolution.
106
+ if down > 1 and up == 1:
107
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109
+ return x
110
+
111
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112
+ if up > 1:
113
+ if groups == 1:
114
+ w = w.transpose(0, 1)
115
+ else:
116
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117
+ w = w.transpose(1, 2)
118
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119
+ px0 -= kw - 1
120
+ px1 -= kw - up
121
+ py0 -= kh - 1
122
+ py1 -= kh - up
123
+ pxt = max(min(-px0, -px1), 0)
124
+ pyt = max(min(-py0, -py1), 0)
125
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127
+ if down > 1:
128
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129
+ return x
130
+
131
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132
+ if up == 1 and down == 1:
133
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135
+
136
+ # Fallback: Generic reference implementation.
137
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139
+ if down > 1:
140
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141
+ return x
142
+
143
+ #----------------------------------------------------------------------------
torch_utils/ops/filtered_lrelu.cpp ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "filtered_lrelu.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
17
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
18
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
19
+ {
20
+ // Set CUDA device.
21
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
22
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
23
+
24
+ // Validate arguments.
25
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
26
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
27
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
28
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
29
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
30
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
31
+ TORCH_CHECK(x.numel() > 0, "x is empty");
32
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
33
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
34
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
35
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
36
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
37
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
38
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
39
+
40
+ // Figure out how much shared memory is available on the device.
41
+ int maxSharedBytes = 0;
42
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
43
+ int sharedKB = maxSharedBytes >> 10;
44
+
45
+ // Populate enough launch parameters to check if a CUDA kernel exists.
46
+ filtered_lrelu_kernel_params p;
47
+ p.up = up;
48
+ p.down = down;
49
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
50
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
51
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
52
+ if (!test_spec.exec)
53
+ {
54
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
55
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
56
+ }
57
+
58
+ // Input/output element size.
59
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
60
+
61
+ // Input sizes.
62
+ int64_t xw = (int)x.size(3);
63
+ int64_t xh = (int)x.size(2);
64
+ int64_t fut_w = (int)fu.size(-1) - 1;
65
+ int64_t fut_h = (int)fu.size(0) - 1;
66
+ int64_t fdt_w = (int)fd.size(-1) - 1;
67
+ int64_t fdt_h = (int)fd.size(0) - 1;
68
+
69
+ // Logical size of upsampled buffer.
70
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
71
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
72
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
73
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
74
+
75
+ // Compute output size and allocate.
76
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
77
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
78
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
79
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
80
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
81
+
82
+ // Allocate sign tensor.
83
+ torch::Tensor so;
84
+ torch::Tensor s = si;
85
+ bool readSigns = !!s.numel();
86
+ int64_t sw_active = 0; // Active width of sign tensor.
87
+ if (writeSigns)
88
+ {
89
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
90
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
91
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
92
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
93
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
94
+ }
95
+ else if (readSigns)
96
+ sw_active = s.size(3) << 2;
97
+
98
+ // Validate sign tensor if in use.
99
+ if (readSigns || writeSigns)
100
+ {
101
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
102
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
103
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
104
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
105
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
106
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
107
+ }
108
+
109
+ // Populate rest of CUDA kernel parameters.
110
+ p.x = x.data_ptr();
111
+ p.y = y.data_ptr();
112
+ p.b = b.data_ptr();
113
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
114
+ p.fu = fu.data_ptr<float>();
115
+ p.fd = fd.data_ptr<float>();
116
+ p.pad0 = make_int2(px0, py0);
117
+ p.gain = gain;
118
+ p.slope = slope;
119
+ p.clamp = clamp;
120
+ p.flip = (flip_filters) ? 1 : 0;
121
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
122
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
123
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
124
+ p.sOfs = make_int2(sx, sy);
125
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
126
+
127
+ // x, y, b strides are in bytes.
128
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
129
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
130
+ p.bStride = sz * b.stride(0);
131
+
132
+ // fu, fd strides are in elements.
133
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
134
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
135
+
136
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
137
+ bool index64b = false;
138
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
139
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
140
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
141
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
142
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
143
+ if (s.numel() > INT_MAX) index64b = true;
144
+
145
+ // Choose CUDA kernel.
146
+ filtered_lrelu_kernel_spec spec = { 0 };
147
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
148
+ {
149
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
150
+ {
151
+ // Choose kernel based on index type, datatype and sign read/write modes.
152
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
153
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
154
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
155
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
156
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
157
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
158
+ }
159
+ });
160
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
161
+
162
+ // Launch CUDA kernel.
163
+ void* args[] = {&p};
164
+ int bx = spec.numWarps * 32;
165
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
166
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
167
+ int gz = p.yShape.z * p.yShape.w;
168
+
169
+ // Repeat multiple horizontal tiles in a CTA?
170
+ if (spec.xrep)
171
+ {
172
+ p.tilesXrep = spec.xrep;
173
+ p.tilesXdim = gx;
174
+
175
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
176
+ std::swap(gx, gy);
177
+ }
178
+ else
179
+ {
180
+ p.tilesXrep = 0;
181
+ p.tilesXdim = 0;
182
+ }
183
+
184
+ // Launch filter setup kernel.
185
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
186
+
187
+ // Copy kernels to constant memory.
188
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
189
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
190
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
191
+
192
+ // Set cache and shared memory configurations for main kernel.
193
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
194
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
195
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
196
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
197
+
198
+ // Launch main kernel.
199
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
200
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
201
+ {
202
+ p.blockZofs = zofs;
203
+ int subGz = std::min(maxSubGz, gz - zofs);
204
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
205
+ }
206
+
207
+ // Done.
208
+ return std::make_tuple(y, so, 0);
209
+ }
210
+
211
+ //------------------------------------------------------------------------
212
+
213
+ static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
214
+ {
215
+ // Set CUDA device.
216
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
217
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
218
+
219
+ // Validate arguments.
220
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
221
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
222
+ TORCH_CHECK(x.numel() > 0, "x is empty");
223
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
224
+
225
+ // Output signs if we don't have sign input.
226
+ torch::Tensor so;
227
+ torch::Tensor s = si;
228
+ bool readSigns = !!s.numel();
229
+ if (writeSigns)
230
+ {
231
+ int64_t sw = x.size(3);
232
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
233
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
234
+ }
235
+
236
+ // Validate sign tensor if in use.
237
+ if (readSigns || writeSigns)
238
+ {
239
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
240
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
241
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
242
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
243
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
244
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
245
+ }
246
+
247
+ // Initialize CUDA kernel parameters.
248
+ filtered_lrelu_act_kernel_params p;
249
+ p.x = x.data_ptr();
250
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
251
+ p.gain = gain;
252
+ p.slope = slope;
253
+ p.clamp = clamp;
254
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
255
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
256
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
257
+ p.sOfs = make_int2(sx, sy);
258
+
259
+ // Choose CUDA kernel.
260
+ void* func = 0;
261
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
262
+ {
263
+ if (writeSigns)
264
+ func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
265
+ else if (readSigns)
266
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
267
+ else
268
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
269
+ });
270
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
271
+
272
+ // Launch CUDA kernel.
273
+ void* args[] = {&p};
274
+ int bx = 128; // 4 warps per block.
275
+
276
+ // Logical size of launch = writeSigns ? p.s : p.x
277
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
278
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
279
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
280
+ gx = (gx - 1) / bx + 1;
281
+
282
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
283
+ const uint32_t gmax = 65535;
284
+ gy = std::min(gy, gmax);
285
+ gz = std::min(gz, gmax);
286
+
287
+ // Launch.
288
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
289
+ return so;
290
+ }
291
+
292
+ //------------------------------------------------------------------------
293
+
294
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
295
+ {
296
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
297
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
298
+ }
299
+
300
+ //------------------------------------------------------------------------
torch_utils/ops/filtered_lrelu.cu ADDED
@@ -0,0 +1,1284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "filtered_lrelu.h"
11
+ #include <cstdint>
12
+
13
+ //------------------------------------------------------------------------
14
+ // Helpers.
15
+
16
+ enum // Filter modes.
17
+ {
18
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
19
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
20
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
21
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
22
+ };
23
+
24
+ template <class T> struct InternalType;
25
+ template <> struct InternalType<double>
26
+ {
27
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
28
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
29
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
30
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
31
+ };
32
+ template <> struct InternalType<float>
33
+ {
34
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
35
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
36
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
37
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
38
+ };
39
+ template <> struct InternalType<c10::Half>
40
+ {
41
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
42
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
43
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
44
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
45
+ };
46
+
47
+ #define MIN(A, B) ((A) < (B) ? (A) : (B))
48
+ #define MAX(A, B) ((A) > (B) ? (A) : (B))
49
+ #define CEIL_DIV(A, B) (((B)==1) ? (A) : \
50
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
51
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
52
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
53
+
54
+ // This works only up to blocks of size 256 x 256 and for all N that are powers of two.
55
+ template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
56
+ {
57
+ if ((N & (N-1)) && N <= 256)
58
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
59
+ else
60
+ y = i/N;
61
+
62
+ x = i - y*N;
63
+ }
64
+
65
+ // Type cast stride before reading it.
66
+ template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
67
+ {
68
+ return *reinterpret_cast<const T*>(&x);
69
+ }
70
+
71
+ //------------------------------------------------------------------------
72
+ // Filters, setup kernel, copying function.
73
+
74
+ #define MAX_FILTER_SIZE 32
75
+
76
+ // Combined up/down filter buffers so that transfer can be done with one copy.
77
+ __device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
78
+ __device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
79
+
80
+ // Accessors to combined buffers to index up/down filters individually.
81
+ #define c_fu (c_fbuf)
82
+ #define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
83
+ #define g_fu (g_fbuf)
84
+ #define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
85
+
86
+ // Set up filters into global memory buffer.
87
+ static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
88
+ {
89
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
90
+ {
91
+ int x, y;
92
+ fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
93
+
94
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
95
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
96
+ if (p.fuShape.y > 0)
97
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
98
+ else
99
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
100
+
101
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
102
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
103
+ if (p.fdShape.y > 0)
104
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
105
+ else
106
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
107
+ }
108
+ }
109
+
110
+ // Host function to copy filters written by setup kernel into constant buffer for main kernel.
111
+ template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
112
+ {
113
+ void* src = 0;
114
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
115
+ if (err) return err;
116
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
117
+ }
118
+
119
+ //------------------------------------------------------------------------
120
+ // Coordinate spaces:
121
+ // - Relative to input tensor: inX, inY, tileInX, tileInY
122
+ // - Relative to input tile: relInX, relInY, tileInW, tileInH
123
+ // - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
124
+ // - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
125
+ // - Relative to output tensor: outX, outY, tileOutX, tileOutY
126
+ //
127
+ // Relationships between coordinate spaces:
128
+ // - inX = tileInX + relInX
129
+ // - inY = tileInY + relInY
130
+ // - relUpX = relInX * up + phaseInX
131
+ // - relUpY = relInY * up + phaseInY
132
+ // - relUpX = relOutX * down
133
+ // - relUpY = relOutY * down
134
+ // - outX = tileOutX + relOutX
135
+ // - outY = tileOutY + relOutY
136
+
137
+ extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
138
+
139
+ template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
140
+ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
141
+ {
142
+ // Check that we don't try to support non-existing filter modes.
143
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
144
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
145
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
146
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
147
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
148
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
149
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
150
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
151
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
152
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
153
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
154
+
155
+ // Static definitions.
156
+ typedef typename InternalType<T>::scalar_t scalar_t;
157
+ typedef typename InternalType<T>::vec2_t vec2_t;
158
+ typedef typename InternalType<T>::vec4_t vec4_t;
159
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
160
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
161
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
162
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
163
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
164
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
165
+
166
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
167
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
168
+
169
+ // Sizes of logical buffers.
170
+ const int szIn = tileInH_up * tileInW;
171
+ const int szUpX = tileInH_up * tileUpW;
172
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
173
+ const int szDownX = tileUpH * tileOutW;
174
+
175
+ // Sizes for shared memory arrays.
176
+ const int s_buf0_size_base =
177
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
178
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
179
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
180
+ (filterMode == MODE_FUFD) ? szIn :
181
+ -1;
182
+ const int s_buf1_size_base =
183
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
184
+ (filterMode == MODE_FUSD) ? szUpXY :
185
+ (filterMode == MODE_SUFD) ? szUpX :
186
+ (filterMode == MODE_FUFD) ? szUpXY :
187
+ -1;
188
+
189
+ // Ensure U128 alignment.
190
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
191
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
192
+
193
+ // Check at compile time that we don't use too much shared memory.
194
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
195
+
196
+ // Declare shared memory arrays.
197
+ scalar_t* s_buf0;
198
+ scalar_t* s_buf1;
199
+ if (sharedKB <= 48)
200
+ {
201
+ // Allocate shared memory arrays here.
202
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
203
+ s_buf0 = s_buf0_st;
204
+ s_buf1 = s_buf0 + s_buf0_size;
205
+ }
206
+ else
207
+ {
208
+ // Use the dynamically allocated shared memory array.
209
+ s_buf0 = (scalar_t*)s_buf_raw;
210
+ s_buf1 = s_buf0 + s_buf0_size;
211
+ }
212
+
213
+ // Pointers to the buffers.
214
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
215
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
216
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
217
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
218
+ if (filterMode == MODE_SUSD)
219
+ {
220
+ s_tileIn = s_buf0;
221
+ s_tileUpX = s_buf1;
222
+ s_tileUpXY = s_buf0;
223
+ s_tileDownX = s_buf1;
224
+ }
225
+ else if (filterMode == MODE_FUSD)
226
+ {
227
+ s_tileIn = s_buf0;
228
+ s_tileUpXY = s_buf1;
229
+ s_tileDownX = s_buf0;
230
+ }
231
+ else if (filterMode == MODE_SUFD)
232
+ {
233
+ s_tileIn = s_buf0;
234
+ s_tileUpX = s_buf1;
235
+ s_tileUpXY = s_buf0;
236
+ }
237
+ else if (filterMode == MODE_FUFD)
238
+ {
239
+ s_tileIn = s_buf0;
240
+ s_tileUpXY = s_buf1;
241
+ }
242
+
243
+ // Allow large grids in z direction via per-launch offset.
244
+ int channelIdx = blockIdx.z + p.blockZofs;
245
+ int batchIdx = channelIdx / p.yShape.z;
246
+ channelIdx -= batchIdx * p.yShape.z;
247
+
248
+ // Offset to output feature map. In bytes.
249
+ index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
250
+
251
+ // Sign shift amount.
252
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
253
+
254
+ // Inner tile loop.
255
+ #pragma unroll 1
256
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
257
+ {
258
+ // Locate output tile.
259
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
260
+ int tileOutX = tileX * tileOutW;
261
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
262
+
263
+ // Locate input tile.
264
+ int tmpX = tileOutX * down - p.pad0.x;
265
+ int tmpY = tileOutY * down - p.pad0.y;
266
+ int tileInX = CEIL_DIV(tmpX, up);
267
+ int tileInY = CEIL_DIV(tmpY, up);
268
+ const int phaseInX = tileInX * up - tmpX;
269
+ const int phaseInY = tileInY * up - tmpY;
270
+
271
+ // Extra sync if input and output buffers are the same and we are not on first tile.
272
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
273
+ __syncthreads();
274
+
275
+ // Load input tile & apply bias. Unrolled.
276
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
277
+ index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
278
+ int idx = threadIdx.x;
279
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
280
+ #pragma unroll
281
+ for (int loop = 0; loop < loopCountIN; loop++)
282
+ {
283
+ int relInX, relInY;
284
+ fast_div_mod<tileInW>(relInX, relInY, idx);
285
+ int inX = tileInX + relInX;
286
+ int inY = tileInY + relInY;
287
+ scalar_t v = 0;
288
+
289
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
290
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
291
+
292
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
293
+ if (!skip)
294
+ s_tileIn[idx] = v;
295
+
296
+ idx += threadsPerBlock;
297
+ }
298
+
299
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
300
+ {
301
+ // Horizontal upsampling.
302
+ __syncthreads();
303
+ if (up == 4)
304
+ {
305
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
306
+ {
307
+ int relUpX0, relInY;
308
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
309
+ int relInX0 = relUpX0 / up;
310
+ int src0 = relInX0 + tileInW * relInY;
311
+ int dst = relInY * tileUpW + relUpX0;
312
+ vec4_t v = InternalType<T>::zero_vec4();
313
+ scalar_t a = s_tileIn[src0];
314
+ if (phaseInX == 0)
315
+ {
316
+ #pragma unroll
317
+ for (int step = 0; step < fuSize / up; step++)
318
+ {
319
+ v.x += a * (scalar_t)c_fu[step * up + 0];
320
+ a = s_tileIn[src0 + step + 1];
321
+ v.y += a * (scalar_t)c_fu[step * up + 3];
322
+ v.z += a * (scalar_t)c_fu[step * up + 2];
323
+ v.w += a * (scalar_t)c_fu[step * up + 1];
324
+ }
325
+ }
326
+ else if (phaseInX == 1)
327
+ {
328
+ #pragma unroll
329
+ for (int step = 0; step < fuSize / up; step++)
330
+ {
331
+ v.x += a * (scalar_t)c_fu[step * up + 1];
332
+ v.y += a * (scalar_t)c_fu[step * up + 0];
333
+ a = s_tileIn[src0 + step + 1];
334
+ v.z += a * (scalar_t)c_fu[step * up + 3];
335
+ v.w += a * (scalar_t)c_fu[step * up + 2];
336
+ }
337
+ }
338
+ else if (phaseInX == 2)
339
+ {
340
+ #pragma unroll
341
+ for (int step = 0; step < fuSize / up; step++)
342
+ {
343
+ v.x += a * (scalar_t)c_fu[step * up + 2];
344
+ v.y += a * (scalar_t)c_fu[step * up + 1];
345
+ v.z += a * (scalar_t)c_fu[step * up + 0];
346
+ a = s_tileIn[src0 + step + 1];
347
+ v.w += a * (scalar_t)c_fu[step * up + 3];
348
+ }
349
+ }
350
+ else // (phaseInX == 3)
351
+ {
352
+ #pragma unroll
353
+ for (int step = 0; step < fuSize / up; step++)
354
+ {
355
+ v.x += a * (scalar_t)c_fu[step * up + 3];
356
+ v.y += a * (scalar_t)c_fu[step * up + 2];
357
+ v.z += a * (scalar_t)c_fu[step * up + 1];
358
+ v.w += a * (scalar_t)c_fu[step * up + 0];
359
+ a = s_tileIn[src0 + step + 1];
360
+ }
361
+ }
362
+ s_tileUpX[dst+0] = v.x;
363
+ s_tileUpX[dst+1] = v.y;
364
+ s_tileUpX[dst+2] = v.z;
365
+ s_tileUpX[dst+3] = v.w;
366
+ }
367
+ }
368
+ else if (up == 2)
369
+ {
370
+ bool p0 = (phaseInX == 0);
371
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
372
+ {
373
+ int relUpX0, relInY;
374
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
375
+ int relInX0 = relUpX0 / up;
376
+ int src0 = relInX0 + tileInW * relInY;
377
+ int dst = relInY * tileUpW + relUpX0;
378
+ vec2_t v = InternalType<T>::zero_vec2();
379
+ scalar_t a = s_tileIn[src0];
380
+ if (p0) // (phaseInX == 0)
381
+ {
382
+ #pragma unroll
383
+ for (int step = 0; step < fuSize / up; step++)
384
+ {
385
+ v.x += a * (scalar_t)c_fu[step * up + 0];
386
+ a = s_tileIn[src0 + step + 1];
387
+ v.y += a * (scalar_t)c_fu[step * up + 1];
388
+ }
389
+ }
390
+ else // (phaseInX == 1)
391
+ {
392
+ #pragma unroll
393
+ for (int step = 0; step < fuSize / up; step++)
394
+ {
395
+ v.x += a * (scalar_t)c_fu[step * up + 1];
396
+ v.y += a * (scalar_t)c_fu[step * up + 0];
397
+ a = s_tileIn[src0 + step + 1];
398
+ }
399
+ }
400
+ s_tileUpX[dst+0] = v.x;
401
+ s_tileUpX[dst+1] = v.y;
402
+ }
403
+ }
404
+
405
+ // Vertical upsampling & nonlinearity.
406
+
407
+ __syncthreads();
408
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
409
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
410
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
411
+ if (up == 4)
412
+ {
413
+ minY -= 3; // Adjust according to block height.
414
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
415
+ {
416
+ int relUpX, relInY0;
417
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
418
+ int relUpY0 = relInY0 * up;
419
+ int src0 = relInY0 * tileUpW + relUpX;
420
+ int dst = relUpY0 * tileUpW + relUpX;
421
+ vec4_t v = InternalType<T>::zero_vec4();
422
+
423
+ scalar_t a = s_tileUpX[src0];
424
+ if (phaseInY == 0)
425
+ {
426
+ #pragma unroll
427
+ for (int step = 0; step < fuSize / up; step++)
428
+ {
429
+ v.x += a * (scalar_t)c_fu[step * up + 0];
430
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
431
+ v.y += a * (scalar_t)c_fu[step * up + 3];
432
+ v.z += a * (scalar_t)c_fu[step * up + 2];
433
+ v.w += a * (scalar_t)c_fu[step * up + 1];
434
+ }
435
+ }
436
+ else if (phaseInY == 1)
437
+ {
438
+ #pragma unroll
439
+ for (int step = 0; step < fuSize / up; step++)
440
+ {
441
+ v.x += a * (scalar_t)c_fu[step * up + 1];
442
+ v.y += a * (scalar_t)c_fu[step * up + 0];
443
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
444
+ v.z += a * (scalar_t)c_fu[step * up + 3];
445
+ v.w += a * (scalar_t)c_fu[step * up + 2];
446
+ }
447
+ }
448
+ else if (phaseInY == 2)
449
+ {
450
+ #pragma unroll
451
+ for (int step = 0; step < fuSize / up; step++)
452
+ {
453
+ v.x += a * (scalar_t)c_fu[step * up + 2];
454
+ v.y += a * (scalar_t)c_fu[step * up + 1];
455
+ v.z += a * (scalar_t)c_fu[step * up + 0];
456
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
457
+ v.w += a * (scalar_t)c_fu[step * up + 3];
458
+ }
459
+ }
460
+ else // (phaseInY == 3)
461
+ {
462
+ #pragma unroll
463
+ for (int step = 0; step < fuSize / up; step++)
464
+ {
465
+ v.x += a * (scalar_t)c_fu[step * up + 3];
466
+ v.y += a * (scalar_t)c_fu[step * up + 2];
467
+ v.z += a * (scalar_t)c_fu[step * up + 1];
468
+ v.w += a * (scalar_t)c_fu[step * up + 0];
469
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
470
+ }
471
+ }
472
+
473
+ int x = tileOutX * down + relUpX;
474
+ int y = tileOutY * down + relUpY0;
475
+ int signX = x + p.sOfs.x;
476
+ int signY = y + p.sOfs.y;
477
+ int signZ = blockIdx.z + p.blockZofs;
478
+ int signXb = signX >> 2;
479
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
480
+ index_t si1 = si0 + p.sShape.x;
481
+ index_t si2 = si0 + p.sShape.x * 2;
482
+ index_t si3 = si0 + p.sShape.x * 3;
483
+
484
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
485
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
486
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
487
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
488
+
489
+ if (signWrite)
490
+ {
491
+ if (!enableWriteSkip)
492
+ {
493
+ // Determine and write signs.
494
+ int sx = __float_as_uint(v.x) >> 31 << 0;
495
+ int sy = __float_as_uint(v.y) >> 31 << 8;
496
+ int sz = __float_as_uint(v.z) >> 31 << 16;
497
+ int sw = __float_as_uint(v.w) >> 31 << 24;
498
+ if (sx) v.x *= p.slope;
499
+ if (sy) v.y *= p.slope;
500
+ if (sz) v.z *= p.slope;
501
+ if (sw) v.w *= p.slope;
502
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
503
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
504
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
505
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
506
+
507
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
508
+ {
509
+ // Combine signs.
510
+ uint32_t s = sx + sy + sw + sz;
511
+ s <<= (signX & 3) << 1;
512
+ s |= __shfl_xor_sync(groupMask, s, 1);
513
+ s |= __shfl_xor_sync(groupMask, s, 2);
514
+
515
+ // Write signs.
516
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
517
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
518
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
519
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
520
+ }
521
+ }
522
+ else
523
+ {
524
+ // Determine and write signs.
525
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
526
+ {
527
+ int sx = __float_as_uint(v.x) >> 31 << 0;
528
+ int sy = __float_as_uint(v.y) >> 31 << 8;
529
+ int sz = __float_as_uint(v.z) >> 31 << 16;
530
+ int sw = __float_as_uint(v.w) >> 31 << 24;
531
+ if (sx) v.x *= p.slope;
532
+ if (sy) v.y *= p.slope;
533
+ if (sz) v.z *= p.slope;
534
+ if (sw) v.w *= p.slope;
535
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
536
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
537
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
538
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
539
+
540
+ // Combine signs.
541
+ uint32_t s = sx + sy + sw + sz;
542
+ s <<= (signX & 3) << 1;
543
+ s |= __shfl_xor_sync(groupMask, s, 1);
544
+ s |= __shfl_xor_sync(groupMask, s, 2);
545
+
546
+ // Write signs.
547
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
548
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
549
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
550
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
551
+ }
552
+ else
553
+ {
554
+ // Just compute the values.
555
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
556
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
557
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
558
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
559
+ }
560
+ }
561
+ }
562
+ else if (signRead) // Read signs and apply.
563
+ {
564
+ if ((uint32_t)signXb < p.swLimit)
565
+ {
566
+ int ss = (signX & 3) << 1;
567
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
568
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
569
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
570
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
571
+ }
572
+ }
573
+ else // Forward pass with no sign write.
574
+ {
575
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
576
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
577
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
578
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
579
+ }
580
+
581
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
582
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
583
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
584
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
585
+ }
586
+ }
587
+ else if (up == 2)
588
+ {
589
+ minY -= 1; // Adjust according to block height.
590
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
591
+ {
592
+ int relUpX, relInY0;
593
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
594
+ int relUpY0 = relInY0 * up;
595
+ int src0 = relInY0 * tileUpW + relUpX;
596
+ int dst = relUpY0 * tileUpW + relUpX;
597
+ vec2_t v = InternalType<T>::zero_vec2();
598
+
599
+ scalar_t a = s_tileUpX[src0];
600
+ if (phaseInY == 0)
601
+ {
602
+ #pragma unroll
603
+ for (int step = 0; step < fuSize / up; step++)
604
+ {
605
+ v.x += a * (scalar_t)c_fu[step * up + 0];
606
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
607
+ v.y += a * (scalar_t)c_fu[step * up + 1];
608
+ }
609
+ }
610
+ else // (phaseInY == 1)
611
+ {
612
+ #pragma unroll
613
+ for (int step = 0; step < fuSize / up; step++)
614
+ {
615
+ v.x += a * (scalar_t)c_fu[step * up + 1];
616
+ v.y += a * (scalar_t)c_fu[step * up + 0];
617
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
618
+ }
619
+ }
620
+
621
+ int x = tileOutX * down + relUpX;
622
+ int y = tileOutY * down + relUpY0;
623
+ int signX = x + p.sOfs.x;
624
+ int signY = y + p.sOfs.y;
625
+ int signZ = blockIdx.z + p.blockZofs;
626
+ int signXb = signX >> 2;
627
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
628
+ index_t si1 = si0 + p.sShape.x;
629
+
630
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
631
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
632
+
633
+ if (signWrite)
634
+ {
635
+ if (!enableWriteSkip)
636
+ {
637
+ // Determine and write signs.
638
+ int sx = __float_as_uint(v.x) >> 31 << 0;
639
+ int sy = __float_as_uint(v.y) >> 31 << 8;
640
+ if (sx) v.x *= p.slope;
641
+ if (sy) v.y *= p.slope;
642
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
643
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
644
+
645
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
646
+ {
647
+ // Combine signs.
648
+ int s = sx + sy;
649
+ s <<= signXo;
650
+ s |= __shfl_xor_sync(groupMask, s, 1);
651
+ s |= __shfl_xor_sync(groupMask, s, 2);
652
+
653
+ // Write signs.
654
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
655
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
656
+ }
657
+ }
658
+ else
659
+ {
660
+ // Determine and write signs.
661
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
662
+ {
663
+ int sx = __float_as_uint(v.x) >> 31 << 0;
664
+ int sy = __float_as_uint(v.y) >> 31 << 8;
665
+ if (sx) v.x *= p.slope;
666
+ if (sy) v.y *= p.slope;
667
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
668
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
669
+
670
+ // Combine signs.
671
+ int s = sx + sy;
672
+ s <<= signXo;
673
+ s |= __shfl_xor_sync(groupMask, s, 1);
674
+ s |= __shfl_xor_sync(groupMask, s, 2);
675
+
676
+ // Write signs.
677
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
678
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
679
+ }
680
+ else
681
+ {
682
+ // Just compute the values.
683
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
684
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
685
+ }
686
+ }
687
+ }
688
+ else if (signRead) // Read signs and apply.
689
+ {
690
+ if ((uint32_t)signXb < p.swLimit)
691
+ {
692
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
693
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
694
+ }
695
+ }
696
+ else // Forward pass with no sign write.
697
+ {
698
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
699
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
700
+ }
701
+
702
+ if (!downInline)
703
+ {
704
+ // Write into temporary buffer.
705
+ s_tileUpXY[dst] = v.x;
706
+ if (relUpY0 < tileUpH - 1)
707
+ s_tileUpXY[dst + tileUpW] = v.y;
708
+ }
709
+ else
710
+ {
711
+ // Write directly into output buffer.
712
+ if ((uint32_t)x < p.yShape.x)
713
+ {
714
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
715
+ index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
716
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
717
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
718
+ }
719
+ }
720
+ }
721
+ }
722
+ }
723
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
724
+ {
725
+ // Full upsampling filter.
726
+
727
+ if (up == 2)
728
+ {
729
+ // 2 x 2-wide.
730
+ __syncthreads();
731
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
732
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
733
+ {
734
+ int relUpX0, relUpY0;
735
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
736
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
737
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
738
+ int src0 = relInX0 + tileInW * relInY0;
739
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
740
+
741
+ #define X_LOOP(TAPY, PX) \
742
+ for (int sx = 0; sx < fuSize / up; sx++) \
743
+ { \
744
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
745
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
746
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
747
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
748
+ }
749
+
750
+ vec4_t v = InternalType<T>::zero_vec4();
751
+ if (tap0y == 0 && phaseInX == 0)
752
+ #pragma unroll
753
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
754
+ #pragma unroll
755
+ X_LOOP(0, 0) }
756
+ if (tap0y == 0 && phaseInX == 1)
757
+ #pragma unroll
758
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
759
+ #pragma unroll
760
+ X_LOOP(0, 1) }
761
+ if (tap0y == 1 && phaseInX == 0)
762
+ #pragma unroll
763
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
764
+ #pragma unroll
765
+ X_LOOP(1, 0) }
766
+ if (tap0y == 1 && phaseInX == 1)
767
+ #pragma unroll
768
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
769
+ #pragma unroll
770
+ X_LOOP(1, 1) }
771
+
772
+ #undef X_LOOP
773
+
774
+ int x = tileOutX * down + relUpX0;
775
+ int y = tileOutY * down + relUpY0;
776
+ int signX = x + p.sOfs.x;
777
+ int signY = y + p.sOfs.y;
778
+ int signZ = blockIdx.z + p.blockZofs;
779
+ int signXb = signX >> 2;
780
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
781
+
782
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
783
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
784
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
785
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
786
+
787
+ if (signWrite)
788
+ {
789
+ if (!enableWriteSkip)
790
+ {
791
+ // Determine and write signs.
792
+ int sx = __float_as_uint(v.x) >> 31;
793
+ int sy = __float_as_uint(v.y) >> 31;
794
+ int sz = __float_as_uint(v.z) >> 31;
795
+ int sw = __float_as_uint(v.w) >> 31;
796
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
797
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
798
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
799
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
800
+
801
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
802
+ {
803
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
804
+ }
805
+ }
806
+ else
807
+ {
808
+ // Determine and write signs.
809
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
810
+ {
811
+ int sx = __float_as_uint(v.x) >> 31;
812
+ int sy = __float_as_uint(v.y) >> 31;
813
+ int sz = __float_as_uint(v.z) >> 31;
814
+ int sw = __float_as_uint(v.w) >> 31;
815
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
816
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
817
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
818
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
819
+
820
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
821
+ }
822
+ else
823
+ {
824
+ // Just compute the values.
825
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
826
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
827
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
828
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
829
+ }
830
+ }
831
+ }
832
+ else if (signRead) // Read sign and apply.
833
+ {
834
+ if ((uint32_t)signY < p.sShape.y)
835
+ {
836
+ int s = 0;
837
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
838
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
839
+ s >>= (signX & 3) << 1;
840
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
841
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
842
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
843
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
844
+ }
845
+ }
846
+ else // Forward pass with no sign write.
847
+ {
848
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
849
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
850
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
851
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
852
+ }
853
+
854
+ s_tileUpXY[idx + 0] = v.x;
855
+ s_tileUpXY[idx + 1] = v.y;
856
+ s_tileUpXY[idx + 2] = v.z;
857
+ s_tileUpXY[idx + 3] = v.w;
858
+ }
859
+ }
860
+ else if (up == 1)
861
+ {
862
+ __syncthreads();
863
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
864
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
865
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
866
+ {
867
+ int relUpX0, relUpY0;
868
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
869
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
870
+
871
+ int x = tileOutX * down + relUpX0;
872
+ int y = tileOutY * down + relUpY0;
873
+ int signX = x + p.sOfs.x;
874
+ int signY = y + p.sOfs.y;
875
+ int signZ = blockIdx.z + p.blockZofs;
876
+ int signXb = signX >> 2;
877
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
878
+ v *= (scalar_t)((float)up * (float)up * p.gain);
879
+
880
+ if (signWrite)
881
+ {
882
+ if (!enableWriteSkip)
883
+ {
884
+ // Determine and write sign.
885
+ uint32_t s = 0;
886
+ uint32_t signXbit = (1u << signXo);
887
+ if (v < 0.f)
888
+ {
889
+ s = signXbit;
890
+ v *= p.slope;
891
+ }
892
+ if (fabsf(v) > p.clamp)
893
+ {
894
+ s = signXbit * 2;
895
+ v = InternalType<T>::clamp(v, p.clamp);
896
+ }
897
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
898
+ {
899
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
900
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
901
+ p.s[si] = s; // Write.
902
+ }
903
+ }
904
+ else
905
+ {
906
+ // Determine and write sign.
907
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
908
+ {
909
+ uint32_t s = 0;
910
+ uint32_t signXbit = (1u << signXo);
911
+ if (v < 0.f)
912
+ {
913
+ s = signXbit;
914
+ v *= p.slope;
915
+ }
916
+ if (fabsf(v) > p.clamp)
917
+ {
918
+ s = signXbit * 2;
919
+ v = InternalType<T>::clamp(v, p.clamp);
920
+ }
921
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
922
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
923
+ p.s[si] = s; // Write.
924
+ }
925
+ else
926
+ {
927
+ // Just compute the value.
928
+ if (v < 0.f) v *= p.slope;
929
+ v = InternalType<T>::clamp(v, p.clamp);
930
+ }
931
+ }
932
+ }
933
+ else if (signRead)
934
+ {
935
+ // Read sign and apply if within sign tensor bounds.
936
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
937
+ {
938
+ int s = p.s[si];
939
+ s >>= signXo;
940
+ if (s & 1) v *= p.slope;
941
+ if (s & 2) v = 0.f;
942
+ }
943
+ }
944
+ else // Forward pass with no sign write.
945
+ {
946
+ if (v < 0.f) v *= p.slope;
947
+ v = InternalType<T>::clamp(v, p.clamp);
948
+ }
949
+
950
+ if (!downInline) // Write into temporary buffer.
951
+ s_tileUpXY[idx] = v;
952
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
953
+ *((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
954
+ }
955
+ }
956
+ }
957
+
958
+ // Downsampling.
959
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
960
+ {
961
+ // Horizontal downsampling.
962
+ __syncthreads();
963
+ if (down == 4 && tileOutW % 4 == 0)
964
+ {
965
+ // Calculate 4 pixels at a time.
966
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
967
+ {
968
+ int relOutX0, relUpY;
969
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
970
+ int relUpX0 = relOutX0 * down;
971
+ int src0 = relUpY * tileUpW + relUpX0;
972
+ vec4_t v = InternalType<T>::zero_vec4();
973
+ #pragma unroll
974
+ for (int step = 0; step < fdSize; step++)
975
+ {
976
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
977
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
978
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
979
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
980
+ }
981
+ s_tileDownX[idx+0] = v.x;
982
+ s_tileDownX[idx+1] = v.y;
983
+ s_tileDownX[idx+2] = v.z;
984
+ s_tileDownX[idx+3] = v.w;
985
+ }
986
+ }
987
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
988
+ {
989
+ // Calculate 2 pixels at a time.
990
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
991
+ {
992
+ int relOutX0, relUpY;
993
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
994
+ int relUpX0 = relOutX0 * down;
995
+ int src0 = relUpY * tileUpW + relUpX0;
996
+ vec2_t v = InternalType<T>::zero_vec2();
997
+ #pragma unroll
998
+ for (int step = 0; step < fdSize; step++)
999
+ {
1000
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
1001
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
1002
+ }
1003
+ s_tileDownX[idx+0] = v.x;
1004
+ s_tileDownX[idx+1] = v.y;
1005
+ }
1006
+ }
1007
+ else
1008
+ {
1009
+ // Calculate 1 pixel at a time.
1010
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
1011
+ {
1012
+ int relOutX0, relUpY;
1013
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
1014
+ int relUpX0 = relOutX0 * down;
1015
+ int src = relUpY * tileUpW + relUpX0;
1016
+ scalar_t v = 0.f;
1017
+ #pragma unroll
1018
+ for (int step = 0; step < fdSize; step++)
1019
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
1020
+ s_tileDownX[idx] = v;
1021
+ }
1022
+ }
1023
+
1024
+ // Vertical downsampling & store output tile.
1025
+ __syncthreads();
1026
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1027
+ {
1028
+ int relOutX, relOutY0;
1029
+ fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
1030
+ int relUpY0 = relOutY0 * down;
1031
+ int src0 = relUpY0 * tileOutW + relOutX;
1032
+ scalar_t v = 0;
1033
+ #pragma unroll
1034
+ for (int step = 0; step < fdSize; step++)
1035
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
1036
+
1037
+ int outX = tileOutX + relOutX;
1038
+ int outY = tileOutY + relOutY0;
1039
+
1040
+ if (outX < p.yShape.x & outY < p.yShape.y)
1041
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1042
+ }
1043
+ }
1044
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
1045
+ {
1046
+ // Full downsampling filter.
1047
+ if (down == 2)
1048
+ {
1049
+ // 2-wide.
1050
+ __syncthreads();
1051
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
1052
+ {
1053
+ int relOutX0, relOutY0;
1054
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1055
+ int relUpX0 = relOutX0 * down;
1056
+ int relUpY0 = relOutY0 * down;
1057
+ int src0 = relUpY0 * tileUpW + relUpX0;
1058
+ vec2_t v = InternalType<T>::zero_vec2();
1059
+ #pragma unroll
1060
+ for (int sy = 0; sy < fdSize; sy++)
1061
+ #pragma unroll
1062
+ for (int sx = 0; sx < fdSize; sx++)
1063
+ {
1064
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1065
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1066
+ }
1067
+
1068
+ int outX = tileOutX + relOutX0;
1069
+ int outY = tileOutY + relOutY0;
1070
+ if ((uint32_t)outY < p.yShape.y)
1071
+ {
1072
+ index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
1073
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
1074
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
1075
+ }
1076
+ }
1077
+ }
1078
+ else if (down == 1 && !downInline)
1079
+ {
1080
+ // Thread per pixel.
1081
+ __syncthreads();
1082
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1083
+ {
1084
+ int relOutX0, relOutY0;
1085
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1086
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
1087
+
1088
+ int outX = tileOutX + relOutX0;
1089
+ int outY = tileOutY + relOutY0;
1090
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
1091
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1092
+ }
1093
+ }
1094
+ }
1095
+
1096
+ if (!enableXrep)
1097
+ break;
1098
+ }
1099
+ }
1100
+
1101
+ //------------------------------------------------------------------------
1102
+ // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
1103
+ // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
1104
+
1105
+ template <class T, bool signWrite, bool signRead>
1106
+ static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
1107
+ {
1108
+ typedef typename InternalType<T>::scalar_t scalar_t;
1109
+
1110
+ // Indexing.
1111
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
1112
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
1113
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
1114
+
1115
+ // Loop to accommodate oversized tensors.
1116
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
1117
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
1118
+ {
1119
+ // Extract z and w (channel, minibatch index).
1120
+ int32_t w = q / p.xShape.z;
1121
+ int32_t z = q - w * p.xShape.z;
1122
+
1123
+ // Choose behavior based on sign read/write mode.
1124
+ if (signWrite)
1125
+ {
1126
+ // Process value if in p.x.
1127
+ uint32_t s = 0;
1128
+ if (x < p.xShape.x && y < p.xShape.y)
1129
+ {
1130
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1131
+ T* pv = ((T*)p.x) + ix;
1132
+ scalar_t v = (scalar_t)(*pv);
1133
+
1134
+ // Gain, LReLU, clamp.
1135
+ v *= p.gain;
1136
+ if (v < 0.f)
1137
+ {
1138
+ v *= p.slope;
1139
+ s = 1; // Sign.
1140
+ }
1141
+ if (fabsf(v) > p.clamp)
1142
+ {
1143
+ v = InternalType<T>::clamp(v, p.clamp);
1144
+ s = 2; // Clamp.
1145
+ }
1146
+
1147
+ *pv = (T)v; // Write value.
1148
+ }
1149
+
1150
+ // Coalesce into threads 0 and 16 of warp.
1151
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
1152
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
1153
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
1154
+ s |= __shfl_xor_sync(m, s, 2);
1155
+ s |= __shfl_xor_sync(m, s, 4);
1156
+ s |= __shfl_xor_sync(m, s, 8);
1157
+
1158
+ // Write signs if leader and in p.s.
1159
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
1160
+ {
1161
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
1162
+ ((uint32_t*)p.s)[is >> 4] = s;
1163
+ }
1164
+ }
1165
+ else if (signRead)
1166
+ {
1167
+ // Process value if in p.x.
1168
+ if (x < p.xShape.x) // y is always in.
1169
+ {
1170
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1171
+ T* pv = ((T*)p.x) + ix;
1172
+ scalar_t v = (scalar_t)(*pv);
1173
+ v *= p.gain;
1174
+
1175
+ // Apply sign buffer offset.
1176
+ uint32_t sx = x + p.sOfs.x;
1177
+ uint32_t sy = y + p.sOfs.y;
1178
+
1179
+ // Read and apply signs if we land inside valid region of sign buffer.
1180
+ if (sx < p.sShape.x && sy < p.sShape.y)
1181
+ {
1182
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
1183
+ unsigned char s = p.s[is];
1184
+ s >>= (sx & 3) << 1; // Shift into place.
1185
+ if (s & 1) // Sign?
1186
+ v *= p.slope;
1187
+ if (s & 2) // Clamp?
1188
+ v = 0.f;
1189
+ }
1190
+
1191
+ *pv = (T)v; // Write value.
1192
+ }
1193
+ }
1194
+ else
1195
+ {
1196
+ // Forward pass with no sign write. Process value if in p.x.
1197
+ if (x < p.xShape.x) // y is always in.
1198
+ {
1199
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1200
+ T* pv = ((T*)p.x) + ix;
1201
+ scalar_t v = (scalar_t)(*pv);
1202
+ v *= p.gain;
1203
+ if (v < 0.f)
1204
+ v *= p.slope;
1205
+ if (fabsf(v) > p.clamp)
1206
+ v = InternalType<T>::clamp(v, p.clamp);
1207
+ *pv = (T)v; // Write value.
1208
+ }
1209
+ }
1210
+ }
1211
+ }
1212
+
1213
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
1214
+ {
1215
+ return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
1216
+ }
1217
+
1218
+ //------------------------------------------------------------------------
1219
+ // CUDA kernel selection.
1220
+
1221
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
1222
+ {
1223
+ filtered_lrelu_kernel_spec s = { 0 };
1224
+
1225
+ // Return the first matching kernel.
1226
+ #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
1227
+ if (sharedKB >= SH) \
1228
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
1229
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
1230
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
1231
+ { \
1232
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
1233
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
1234
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
1235
+ s.setup = (void*)setup_filters_kernel; \
1236
+ s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
1237
+ s.tileOut = make_int2(TW, TH); \
1238
+ s.numWarps = W; \
1239
+ s.xrep = XR; \
1240
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
1241
+ return s; \
1242
+ }
1243
+
1244
+ // Launch parameters for various kernel specializations.
1245
+ // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
1246
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
1247
+
1248
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
1249
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
1250
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
1251
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
1252
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
1253
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
1254
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
1255
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
1256
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
1257
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
1258
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
1259
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
1260
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
1261
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
1262
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
1263
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
1264
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
1265
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
1266
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
1267
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
1268
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
1269
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
1270
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
1271
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
1272
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
1273
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
1274
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
1275
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
1276
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
1277
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
1278
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
1279
+
1280
+ #undef CASE
1281
+ return s; // No kernel found.
1282
+ }
1283
+
1284
+ //------------------------------------------------------------------------
torch_utils/ops/filtered_lrelu.h ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct filtered_lrelu_kernel_params
15
+ {
16
+ // These parameters decide which kernel to use.
17
+ int up; // upsampling ratio (1, 2, 4)
18
+ int down; // downsampling ratio (1, 2, 4)
19
+ int2 fuShape; // [size, 1] | [size, size]
20
+ int2 fdShape; // [size, 1] | [size, size]
21
+
22
+ int _dummy; // Alignment.
23
+
24
+ // Rest of the parameters.
25
+ const void* x; // Input tensor.
26
+ void* y; // Output tensor.
27
+ const void* b; // Bias tensor.
28
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
29
+ const float* fu; // Upsampling filter.
30
+ const float* fd; // Downsampling filter.
31
+
32
+ int2 pad0; // Left/top padding.
33
+ float gain; // Additional gain factor.
34
+ float slope; // Leaky ReLU slope on negative side.
35
+ float clamp; // Clamp after nonlinearity.
36
+ int flip; // Filter kernel flip for gradient computation.
37
+
38
+ int tilesXdim; // Original number of horizontal output tiles.
39
+ int tilesXrep; // Number of horizontal tiles per CTA.
40
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41
+
42
+ int4 xShape; // [width, height, channel, batch]
43
+ int4 yShape; // [width, height, channel, batch]
44
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46
+ int swLimit; // Active width of sign tensor in bytes.
47
+
48
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49
+ longlong4 yStride; //
50
+ int64_t bStride; //
51
+ longlong3 fuStride; //
52
+ longlong3 fdStride; //
53
+ };
54
+
55
+ struct filtered_lrelu_act_kernel_params
56
+ {
57
+ void* x; // Input/output, modified in-place.
58
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
59
+
60
+ float gain; // Additional gain factor.
61
+ float slope; // Leaky ReLU slope on negative side.
62
+ float clamp; // Clamp after nonlinearity.
63
+
64
+ int4 xShape; // [width, height, channel, batch]
65
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
66
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68
+ };
69
+
70
+ //------------------------------------------------------------------------
71
+ // CUDA kernel specialization.
72
+
73
+ struct filtered_lrelu_kernel_spec
74
+ {
75
+ void* setup; // Function for filter kernel setup.
76
+ void* exec; // Function for main operation.
77
+ int2 tileOut; // Width/height of launch tile.
78
+ int numWarps; // Number of warps per thread block, determines launch block size.
79
+ int xrep; // For processing multiple horizontal tiles per thread block.
80
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81
+ };
82
+
83
+ //------------------------------------------------------------------------
84
+ // CUDA kernel selection.
85
+
86
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88
+ template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89
+
90
+ //------------------------------------------------------------------------
torch_utils/ops/filtered_lrelu.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import numpy as np
11
+ import torch
12
+ import warnings
13
+
14
+ from .. import custom_ops
15
+ from .. import misc
16
+ from . import upfirdn2d
17
+ from . import bias_act
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ _plugin = None
22
+
23
+ def _init():
24
+ global _plugin
25
+ if _plugin is None:
26
+ _plugin = custom_ops.get_plugin(
27
+ module_name='filtered_lrelu_plugin',
28
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
29
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
30
+ source_dir=os.path.dirname(__file__),
31
+ extra_cuda_cflags=['--use_fast_math'],
32
+ )
33
+ return True
34
+
35
+ def _get_filter_size(f):
36
+ if f is None:
37
+ return 1, 1
38
+ assert isinstance(f, torch.Tensor)
39
+ assert 1 <= f.ndim <= 2
40
+ return f.shape[-1], f.shape[0] # width, height
41
+
42
+ def _parse_padding(padding):
43
+ if isinstance(padding, int):
44
+ padding = [padding, padding]
45
+ assert isinstance(padding, (list, tuple))
46
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
47
+ padding = [int(x) for x in padding]
48
+ if len(padding) == 2:
49
+ px, py = padding
50
+ padding = [px, px, py, py]
51
+ px0, px1, py0, py1 = padding
52
+ return px0, px1, py0, py1
53
+
54
+ #----------------------------------------------------------------------------
55
+
56
+ def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
57
+ r"""Filtered leaky ReLU for a batch of 2D images.
58
+
59
+ Performs the following sequence of operations for each channel:
60
+
61
+ 1. Add channel-specific bias if provided (`b`).
62
+
63
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
64
+
65
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
66
+ Negative padding corresponds to cropping the image.
67
+
68
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
69
+ so that the footprint of all output pixels lies within the input image.
70
+
71
+ 5. Multiply each value by the provided gain factor (`gain`).
72
+
73
+ 6. Apply leaky ReLU activation function to each value.
74
+
75
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
76
+
77
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
78
+ it so that the footprint of all output pixels lies within the input image.
79
+
80
+ 9. Downsample the image by keeping every Nth pixel (`down`).
81
+
82
+ The fused op is considerably more efficient than performing the same calculation
83
+ using standard PyTorch ops. It supports gradients of arbitrary order.
84
+
85
+ Args:
86
+ x: Float32/float16/float64 input tensor of the shape
87
+ `[batch_size, num_channels, in_height, in_width]`.
88
+ fu: Float32 upsampling FIR filter of the shape
89
+ `[filter_height, filter_width]` (non-separable),
90
+ `[filter_taps]` (separable), or
91
+ `None` (identity).
92
+ fd: Float32 downsampling FIR filter of the shape
93
+ `[filter_height, filter_width]` (non-separable),
94
+ `[filter_taps]` (separable), or
95
+ `None` (identity).
96
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
97
+ as `x`. The length of vector must must match the channel dimension of `x`.
98
+ up: Integer upsampling factor (default: 1).
99
+ down: Integer downsampling factor. (default: 1).
100
+ padding: Padding with respect to the upsampled image. Can be a single number
101
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
102
+ (default: 0).
103
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
104
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
105
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
106
+ flip_filter: False = convolution, True = correlation (default: False).
107
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
108
+
109
+ Returns:
110
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
111
+ """
112
+ assert isinstance(x, torch.Tensor)
113
+ assert impl in ['ref', 'cuda']
114
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
115
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
116
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ @misc.profiled_function
121
+ def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
122
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
123
+ existing `upfirdn2n()` and `bias_act()` ops.
124
+ """
125
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
126
+ fu_w, fu_h = _get_filter_size(fu)
127
+ fd_w, fd_h = _get_filter_size(fd)
128
+ if b is not None:
129
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
130
+ misc.assert_shape(b, [x.shape[1]])
131
+ assert isinstance(up, int) and up >= 1
132
+ assert isinstance(down, int) and down >= 1
133
+ px0, px1, py0, py1 = _parse_padding(padding)
134
+ assert gain == float(gain) and gain > 0
135
+ assert slope == float(slope) and slope >= 0
136
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
137
+
138
+ # Calculate output size.
139
+ batch_size, channels, in_h, in_w = x.shape
140
+ in_dtype = x.dtype
141
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
142
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
143
+
144
+ # Compute using existing ops.
145
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
146
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
147
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
148
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
149
+
150
+ # Check output shape & dtype.
151
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
152
+ assert x.dtype == in_dtype
153
+ return x
154
+
155
+ #----------------------------------------------------------------------------
156
+
157
+ _filtered_lrelu_cuda_cache = dict()
158
+
159
+ def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
160
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
161
+ """
162
+ assert isinstance(up, int) and up >= 1
163
+ assert isinstance(down, int) and down >= 1
164
+ px0, px1, py0, py1 = _parse_padding(padding)
165
+ assert gain == float(gain) and gain > 0
166
+ gain = float(gain)
167
+ assert slope == float(slope) and slope >= 0
168
+ slope = float(slope)
169
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
170
+ clamp = float(clamp if clamp is not None else 'inf')
171
+
172
+ # Lookup from cache.
173
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
174
+ if key in _filtered_lrelu_cuda_cache:
175
+ return _filtered_lrelu_cuda_cache[key]
176
+
177
+ # Forward op.
178
+ class FilteredLReluCuda(torch.autograd.Function):
179
+ @staticmethod
180
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
181
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
182
+
183
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
184
+ if fu is None:
185
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
186
+ if fd is None:
187
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
188
+ assert 1 <= fu.ndim <= 2
189
+ assert 1 <= fd.ndim <= 2
190
+
191
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
192
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
193
+ fu = fu.square()[None]
194
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
195
+ fd = fd.square()[None]
196
+
197
+ # Missing sign input tensor.
198
+ if si is None:
199
+ si = torch.empty([0])
200
+
201
+ # Missing bias tensor.
202
+ if b is None:
203
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
204
+
205
+ # Construct internal sign tensor only if gradients are needed.
206
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
207
+
208
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
209
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
210
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
211
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
212
+
213
+ # Call C++/Cuda plugin if datatype is supported.
214
+ if x.dtype in [torch.float16, torch.float32]:
215
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
216
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
217
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
218
+ else:
219
+ return_code = -1
220
+
221
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
222
+ # only the bit-packed sign tensor is retained for gradient computation.
223
+ if return_code < 0:
224
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
225
+
226
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
227
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
228
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
229
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
230
+
231
+ # Prepare for gradient computation.
232
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
233
+ ctx.x_shape = x.shape
234
+ ctx.y_shape = y.shape
235
+ ctx.s_ofs = sx, sy
236
+ return y
237
+
238
+ @staticmethod
239
+ def backward(ctx, dy): # pylint: disable=arguments-differ
240
+ fu, fd, si = ctx.saved_tensors
241
+ _, _, xh, xw = ctx.x_shape
242
+ _, _, yh, yw = ctx.y_shape
243
+ sx, sy = ctx.s_ofs
244
+ dx = None # 0
245
+ dfu = None; assert not ctx.needs_input_grad[1]
246
+ dfd = None; assert not ctx.needs_input_grad[2]
247
+ db = None # 3
248
+ dsi = None; assert not ctx.needs_input_grad[4]
249
+ dsx = None; assert not ctx.needs_input_grad[5]
250
+ dsy = None; assert not ctx.needs_input_grad[6]
251
+
252
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
253
+ pp = [
254
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
255
+ xw * up - yw * down + px0 - (up - 1),
256
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
257
+ xh * up - yh * down + py0 - (up - 1),
258
+ ]
259
+ gg = gain * (up ** 2) / (down ** 2)
260
+ ff = (not flip_filter)
261
+ sx = sx - (fu.shape[-1] - 1) + px0
262
+ sy = sy - (fu.shape[0] - 1) + py0
263
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
264
+
265
+ if ctx.needs_input_grad[3]:
266
+ db = dx.sum([0, 2, 3])
267
+
268
+ return dx, dfu, dfd, db, dsi, dsx, dsy
269
+
270
+ # Add to cache.
271
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
272
+ return FilteredLReluCuda
273
+
274
+ #----------------------------------------------------------------------------
torch_utils/ops/filtered_lrelu_ns.cu ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for no signs mode (no gradients required).
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, false>(cudaStream_t stream);
torch_utils/ops/filtered_lrelu_rd.cu ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign read mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, true>(cudaStream_t stream);
torch_utils/ops/filtered_lrelu_wr.cu ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign write mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<true, false>(cudaStream_t stream);
torch_utils/ops/fma.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10
+
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+
15
+ def fma(a, b, c): # => a * b + c
16
+ return _FusedMultiplyAdd.apply(a, b, c)
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21
+ @staticmethod
22
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23
+ out = torch.addcmul(c, a, b)
24
+ ctx.save_for_backward(a, b)
25
+ ctx.c_shape = c.shape
26
+ return out
27
+
28
+ @staticmethod
29
+ def backward(ctx, dout): # pylint: disable=arguments-differ
30
+ a, b = ctx.saved_tensors
31
+ c_shape = ctx.c_shape
32
+ da = None
33
+ db = None
34
+ dc = None
35
+
36
+ if ctx.needs_input_grad[0]:
37
+ da = _unbroadcast(dout * b, a.shape)
38
+
39
+ if ctx.needs_input_grad[1]:
40
+ db = _unbroadcast(dout * a, b.shape)
41
+
42
+ if ctx.needs_input_grad[2]:
43
+ dc = _unbroadcast(dout, c_shape)
44
+
45
+ return da, db, dc
46
+
47
+ #----------------------------------------------------------------------------
48
+
49
+ def _unbroadcast(x, shape):
50
+ extra_dims = x.ndim - len(shape)
51
+ assert extra_dims >= 0
52
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53
+ if len(dim):
54
+ x = x.sum(dim=dim, keepdim=True)
55
+ if extra_dims:
56
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
57
+ assert x.shape == shape
58
+ return x
59
+
60
+ #----------------------------------------------------------------------------
torch_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.grid_sample` that
10
+ supports arbitrarily high order gradients between the input and output.
11
+ Only works on 2D images and assumes
12
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
+
14
+ import torch
15
+
16
+ # pylint: disable=redefined-builtin
17
+ # pylint: disable=arguments-differ
18
+ # pylint: disable=protected-access
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ enabled = False # Enable the custom op by setting this to true.
23
+
24
+ #----------------------------------------------------------------------------
25
+
26
+ def grid_sample(input, grid):
27
+ if _should_use_custom_op():
28
+ return _GridSample2dForward.apply(input, grid)
29
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
30
+
31
+ #----------------------------------------------------------------------------
32
+
33
+ def _should_use_custom_op():
34
+ return enabled
35
+
36
+ #----------------------------------------------------------------------------
37
+
38
+ class _GridSample2dForward(torch.autograd.Function):
39
+ @staticmethod
40
+ def forward(ctx, input, grid):
41
+ assert input.ndim == 4
42
+ assert grid.ndim == 4
43
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
44
+ ctx.save_for_backward(input, grid)
45
+ return output
46
+
47
+ @staticmethod
48
+ def backward(ctx, grad_output):
49
+ input, grid = ctx.saved_tensors
50
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
51
+ return grad_input, grad_grid
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ class _GridSample2dBackward(torch.autograd.Function):
56
+ @staticmethod
57
+ def forward(ctx, grad_output, input, grid):
58
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
59
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
60
+ ctx.save_for_backward(grid)
61
+ return grad_input, grad_grid
62
+
63
+ @staticmethod
64
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
65
+ _ = grad2_grad_grid # unused
66
+ grid, = ctx.saved_tensors
67
+ grad2_grad_output = None
68
+ grad2_input = None
69
+ grad2_grid = None
70
+
71
+ if ctx.needs_input_grad[0]:
72
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
73
+
74
+ assert not ctx.needs_input_grad[2]
75
+ return grad2_grad_output, grad2_input, grad2_grid
76
+
77
+ #----------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.numel() > 0, "x has zero size");
25
+ TORCH_CHECK(f.numel() > 0, "f has zero size");
26
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28
+ TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32
+
33
+ // Create output tensor.
34
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40
+ TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41
+
42
+ // Initialize CUDA kernel parameters.
43
+ upfirdn2d_kernel_params p;
44
+ p.x = x.data_ptr();
45
+ p.f = f.data_ptr<float>();
46
+ p.y = y.data_ptr();
47
+ p.up = make_int2(upx, upy);
48
+ p.down = make_int2(downx, downy);
49
+ p.pad0 = make_int2(padx0, pady0);
50
+ p.flip = (flip) ? 1 : 0;
51
+ p.gain = gain;
52
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
55
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
56
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
60
+
61
+ // Choose CUDA kernel.
62
+ upfirdn2d_kernel_spec spec;
63
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64
+ {
65
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
66
+ });
67
+
68
+ // Set looping options.
69
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
70
+ p.loopMinor = spec.loopMinor;
71
+ p.loopX = spec.loopX;
72
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
73
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
74
+
75
+ // Compute grid size.
76
+ dim3 blockSize, gridSize;
77
+ if (spec.tileOutW < 0) // large
78
+ {
79
+ blockSize = dim3(4, 32, 1);
80
+ gridSize = dim3(
81
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83
+ p.launchMajor);
84
+ }
85
+ else // small
86
+ {
87
+ blockSize = dim3(256, 1, 1);
88
+ gridSize = dim3(
89
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91
+ p.launchMajor);
92
+ }
93
+
94
+ // Launch CUDA kernel.
95
+ void* args[] = {&p};
96
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97
+ return y;
98
+ }
99
+
100
+ //------------------------------------------------------------------------
101
+
102
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
103
+ {
104
+ m.def("upfirdn2d", &upfirdn2d);
105
+ }
106
+
107
+ //------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "upfirdn2d.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ static __device__ __forceinline__ int floor_div(int a, int b)
21
+ {
22
+ int t = 1 - a / b;
23
+ return (a + t * b) / b - t;
24
+ }
25
+
26
+ //------------------------------------------------------------------------
27
+ // Generic CUDA implementation for large filters.
28
+
29
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
+ {
31
+ typedef typename InternalType<T>::scalar_t scalar_t;
32
+
33
+ // Calculate thread index.
34
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
+ int outY = minorBase / p.launchMinor;
36
+ minorBase -= outY * p.launchMinor;
37
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
+ int majorBase = blockIdx.z * p.loopMajor;
39
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
+ return;
41
+
42
+ // Setup Y receptive field.
43
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
+ if (p.flip)
48
+ filterY = p.filterSize.y - 1 - filterY;
49
+
50
+ // Loop over major, minor, and X.
51
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
+ {
54
+ int nc = major * p.sizeMinor + minor;
55
+ int n = nc / p.inSize.z;
56
+ int c = nc - n * p.inSize.z;
57
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
+ {
59
+ // Setup X receptive field.
60
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
+ if (p.flip)
65
+ filterX = p.filterSize.x - 1 - filterX;
66
+
67
+ // Initialize pointers.
68
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
+
73
+ // Inner loop.
74
+ scalar_t v = 0;
75
+ for (int y = 0; y < h; y++)
76
+ {
77
+ for (int x = 0; x < w; x++)
78
+ {
79
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
+ xp += p.inStride.x;
81
+ fp += filterStepX;
82
+ }
83
+ xp += p.inStride.y - w * p.inStride.x;
84
+ fp += filterStepY - w * filterStepX;
85
+ }
86
+
87
+ // Store result.
88
+ v *= p.gain;
89
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
+ }
91
+ }
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+ // Specialized CUDA implementation for small filters.
96
+
97
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
+ {
100
+ typedef typename InternalType<T>::scalar_t scalar_t;
101
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
+ __shared__ volatile scalar_t sf[filterH][filterW];
104
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
+
106
+ // Calculate tile index.
107
+ int minorBase = blockIdx.x;
108
+ int tileOutY = minorBase / p.launchMinor;
109
+ minorBase -= tileOutY * p.launchMinor;
110
+ minorBase *= loopMinor;
111
+ tileOutY *= tileOutH;
112
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
+ int majorBase = blockIdx.z * p.loopMajor;
114
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
+ return;
116
+
117
+ // Load filter (flipped).
118
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
+ {
120
+ int fy = tapIdx / filterW;
121
+ int fx = tapIdx - fy * filterW;
122
+ scalar_t v = 0;
123
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
124
+ {
125
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
+ }
129
+ sf[fy][fx] = v;
130
+ }
131
+
132
+ // Loop over major and X.
133
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
+ {
135
+ int baseNC = major * p.sizeMinor + minorBase;
136
+ int n = baseNC / p.inSize.z;
137
+ int baseC = baseNC - n * p.inSize.z;
138
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
+ {
140
+ // Load input pixels.
141
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
+ int tileInX = floor_div(tileMidX, upx);
144
+ int tileInY = floor_div(tileMidY, upy);
145
+ __syncthreads();
146
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
+ {
148
+ int relC = inIdx;
149
+ int relInX = relC / loopMinor;
150
+ int relInY = relInX / tileInW;
151
+ relC -= relInX * loopMinor;
152
+ relInX -= relInY * tileInW;
153
+ int c = baseC + relC;
154
+ int inX = tileInX + relInX;
155
+ int inY = tileInY + relInY;
156
+ scalar_t v = 0;
157
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
+ sx[relInY][relInX][relC] = v;
160
+ }
161
+
162
+ // Loop over output pixels.
163
+ __syncthreads();
164
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
+ {
166
+ int relC = outIdx;
167
+ int relOutX = relC / loopMinor;
168
+ int relOutY = relOutX / tileOutW;
169
+ relC -= relOutX * loopMinor;
170
+ relOutX -= relOutY * tileOutW;
171
+ int c = baseC + relC;
172
+ int outX = tileOutX + relOutX;
173
+ int outY = tileOutY + relOutY;
174
+
175
+ // Setup receptive field.
176
+ int midX = tileMidX + relOutX * downx;
177
+ int midY = tileMidY + relOutY * downy;
178
+ int inX = floor_div(midX, upx);
179
+ int inY = floor_div(midY, upy);
180
+ int relInX = inX - tileInX;
181
+ int relInY = inY - tileInY;
182
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
183
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
184
+
185
+ // Inner loop.
186
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
+ {
188
+ scalar_t v = 0;
189
+ #pragma unroll
190
+ for (int y = 0; y < filterH / upy; y++)
191
+ #pragma unroll
192
+ for (int x = 0; x < filterW / upx; x++)
193
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
+ v *= p.gain;
195
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
+ }
197
+ }
198
+ }
199
+ }
200
+ }
201
+
202
+ //------------------------------------------------------------------------
203
+ // CUDA kernel selection.
204
+
205
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
+ {
207
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
209
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
210
+
211
+ // No up/downsampling.
212
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
213
+ {
214
+ // contiguous
215
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
216
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
217
+ if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
218
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
219
+ if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
220
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
221
+ if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
222
+ if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
223
+ if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
224
+ if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
225
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
226
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
228
+ // channels_last
229
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
230
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
231
+ if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
232
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
233
+ if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
234
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
+ if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
236
+ if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
237
+ if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
238
+ if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
239
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
240
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
241
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
242
+ }
243
+
244
+ // 2x upsampling.
245
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
246
+ {
247
+ // contiguous
248
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
249
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
250
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
+ // channels_last
255
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
256
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
257
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
+ }
262
+ if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
263
+ {
264
+ // contiguous
265
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
266
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
268
+ // channels_last
269
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
270
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
271
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
272
+ }
273
+ if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
274
+ {
275
+ // contiguous
276
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
277
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
278
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
279
+ // channels_last
280
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
281
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
282
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
283
+ }
284
+
285
+ // 2x downsampling.
286
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
287
+ {
288
+ // contiguous
289
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
290
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
291
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
292
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
293
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
294
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
295
+ // channels_last
296
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
297
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
298
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
299
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
300
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
301
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
302
+ }
303
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
304
+ {
305
+ // contiguous
306
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
307
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
308
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
309
+ // channels_last
310
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
311
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
312
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
313
+ }
314
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
315
+ {
316
+ // contiguous
317
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
318
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
319
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
320
+ // channels_last
321
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
322
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
323
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
324
+ }
325
+
326
+ // 4x upsampling.
327
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
328
+ {
329
+ // contiguous
330
+ if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
331
+ if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
332
+ // channels_last
333
+ if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
334
+ if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
335
+ }
336
+ if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
337
+ {
338
+ // contiguous
339
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
340
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
341
+ // channels_last
342
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
343
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
344
+ }
345
+ if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
346
+ {
347
+ // contiguous
348
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
349
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
350
+ // channels_last
351
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
352
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
353
+ }
354
+
355
+ // 4x downsampling (inefficient).
356
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
357
+ {
358
+ // contiguous
359
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
360
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
361
+ // channels_last
362
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
363
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
364
+ }
365
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
366
+ {
367
+ // contiguous
368
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
369
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
370
+ // channels_last
371
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
372
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
373
+ }
374
+ return spec;
375
+ }
376
+
377
+ //------------------------------------------------------------------------
378
+ // Template specializations.
379
+
380
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
381
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
382
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
383
+
384
+ //------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct upfirdn2d_kernel_params
15
+ {
16
+ const void* x;
17
+ const float* f;
18
+ void* y;
19
+
20
+ int2 up;
21
+ int2 down;
22
+ int2 pad0;
23
+ int flip;
24
+ float gain;
25
+
26
+ int4 inSize; // [width, height, channel, batch]
27
+ int4 inStride;
28
+ int2 filterSize; // [width, height]
29
+ int2 filterStride;
30
+ int4 outSize; // [width, height, channel, batch]
31
+ int4 outStride;
32
+ int sizeMinor;
33
+ int sizeMajor;
34
+
35
+ int loopMinor;
36
+ int loopMajor;
37
+ int loopX;
38
+ int launchMinor;
39
+ int launchMajor;
40
+ };
41
+
42
+ //------------------------------------------------------------------------
43
+ // CUDA kernel specialization.
44
+
45
+ struct upfirdn2d_kernel_spec
46
+ {
47
+ void* kernel;
48
+ int tileOutW;
49
+ int tileOutH;
50
+ int loopMinor;
51
+ int loopX;
52
+ };
53
+
54
+ //------------------------------------------------------------------------
55
+ // CUDA kernel selection.
56
+
57
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
+
59
+ //------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient resampling of 2D images."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+
15
+ from .. import custom_ops
16
+ from .. import misc
17
+ from . import conv2d_gradfix
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ _plugin = None
22
+
23
+ def _init():
24
+ global _plugin
25
+ if _plugin is None:
26
+ _plugin = custom_ops.get_plugin(
27
+ module_name='upfirdn2d_plugin',
28
+ sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
29
+ headers=['upfirdn2d.h'],
30
+ source_dir=os.path.dirname(__file__),
31
+ extra_cuda_cflags=['--use_fast_math'],
32
+ )
33
+ return True
34
+
35
+ def _parse_scaling(scaling):
36
+ if isinstance(scaling, int):
37
+ scaling = [scaling, scaling]
38
+ assert isinstance(scaling, (list, tuple))
39
+ assert all(isinstance(x, int) for x in scaling)
40
+ sx, sy = scaling
41
+ assert sx >= 1 and sy >= 1
42
+ return sx, sy
43
+
44
+ def _parse_padding(padding):
45
+ if isinstance(padding, int):
46
+ padding = [padding, padding]
47
+ assert isinstance(padding, (list, tuple))
48
+ assert all(isinstance(x, int) for x in padding)
49
+ if len(padding) == 2:
50
+ padx, pady = padding
51
+ padding = [padx, padx, pady, pady]
52
+ padx0, padx1, pady0, pady1 = padding
53
+ return padx0, padx1, pady0, pady1
54
+
55
+ def _get_filter_size(f):
56
+ if f is None:
57
+ return 1, 1
58
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
59
+ fw = f.shape[-1]
60
+ fh = f.shape[0]
61
+ with misc.suppress_tracer_warnings():
62
+ fw = int(fw)
63
+ fh = int(fh)
64
+ misc.assert_shape(f, [fh, fw][:f.ndim])
65
+ assert fw >= 1 and fh >= 1
66
+ return fw, fh
67
+
68
+ #----------------------------------------------------------------------------
69
+
70
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
71
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
72
+
73
+ Args:
74
+ f: Torch tensor, numpy array, or python list of the shape
75
+ `[filter_height, filter_width]` (non-separable),
76
+ `[filter_taps]` (separable),
77
+ `[]` (impulse), or
78
+ `None` (identity).
79
+ device: Result device (default: cpu).
80
+ normalize: Normalize the filter so that it retains the magnitude
81
+ for constant input signal (DC)? (default: True).
82
+ flip_filter: Flip the filter? (default: False).
83
+ gain: Overall scaling factor for signal magnitude (default: 1).
84
+ separable: Return a separable filter? (default: select automatically).
85
+
86
+ Returns:
87
+ Float32 tensor of the shape
88
+ `[filter_height, filter_width]` (non-separable) or
89
+ `[filter_taps]` (separable).
90
+ """
91
+ # Validate.
92
+ if f is None:
93
+ f = 1
94
+ f = torch.as_tensor(f, dtype=torch.float32)
95
+ assert f.ndim in [0, 1, 2]
96
+ assert f.numel() > 0
97
+ if f.ndim == 0:
98
+ f = f[np.newaxis]
99
+
100
+ # Separable?
101
+ if separable is None:
102
+ separable = (f.ndim == 1 and f.numel() >= 8)
103
+ if f.ndim == 1 and not separable:
104
+ f = f.ger(f)
105
+ assert f.ndim == (1 if separable else 2)
106
+
107
+ # Apply normalize, flip, gain, and device.
108
+ if normalize:
109
+ f /= f.sum()
110
+ if flip_filter:
111
+ f = f.flip(list(range(f.ndim)))
112
+ f = f * (gain ** (f.ndim / 2))
113
+ f = f.to(device=device)
114
+ return f
115
+
116
+ #----------------------------------------------------------------------------
117
+
118
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
119
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
120
+
121
+ Performs the following sequence of operations for each channel:
122
+
123
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
124
+
125
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
126
+ Negative padding corresponds to cropping the image.
127
+
128
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
129
+ so that the footprint of all output pixels lies within the input image.
130
+
131
+ 4. Downsample the image by keeping every Nth pixel (`down`).
132
+
133
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
134
+ The fused op is considerably more efficient than performing the same calculation
135
+ using standard PyTorch ops. It supports gradients of arbitrary order.
136
+
137
+ Args:
138
+ x: Float32/float64/float16 input tensor of the shape
139
+ `[batch_size, num_channels, in_height, in_width]`.
140
+ f: Float32 FIR filter of the shape
141
+ `[filter_height, filter_width]` (non-separable),
142
+ `[filter_taps]` (separable), or
143
+ `None` (identity).
144
+ up: Integer upsampling factor. Can be a single int or a list/tuple
145
+ `[x, y]` (default: 1).
146
+ down: Integer downsampling factor. Can be a single int or a list/tuple
147
+ `[x, y]` (default: 1).
148
+ padding: Padding with respect to the upsampled image. Can be a single number
149
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
150
+ (default: 0).
151
+ flip_filter: False = convolution, True = correlation (default: False).
152
+ gain: Overall scaling factor for signal magnitude (default: 1).
153
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
154
+
155
+ Returns:
156
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
157
+ """
158
+ assert isinstance(x, torch.Tensor)
159
+ assert impl in ['ref', 'cuda']
160
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
161
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
162
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
163
+
164
+ #----------------------------------------------------------------------------
165
+
166
+ @misc.profiled_function
167
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
168
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
169
+ """
170
+ # Validate arguments.
171
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
172
+ if f is None:
173
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
174
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
175
+ assert f.dtype == torch.float32 and not f.requires_grad
176
+ batch_size, num_channels, in_height, in_width = x.shape
177
+ upx, upy = _parse_scaling(up)
178
+ downx, downy = _parse_scaling(down)
179
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
180
+
181
+ # Check that upsampled buffer is not smaller than the filter.
182
+ upW = in_width * upx + padx0 + padx1
183
+ upH = in_height * upy + pady0 + pady1
184
+ assert upW >= f.shape[-1] and upH >= f.shape[0]
185
+
186
+ # Upsample by inserting zeros.
187
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
188
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
189
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
190
+
191
+ # Pad or crop.
192
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
193
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
194
+
195
+ # Setup filter.
196
+ f = f * (gain ** (f.ndim / 2))
197
+ f = f.to(x.dtype)
198
+ if not flip_filter:
199
+ f = f.flip(list(range(f.ndim)))
200
+
201
+ # Convolve with the filter.
202
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
203
+ if f.ndim == 4:
204
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
205
+ else:
206
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
207
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
208
+
209
+ # Downsample by throwing away pixels.
210
+ x = x[:, :, ::downy, ::downx]
211
+ return x
212
+
213
+ #----------------------------------------------------------------------------
214
+
215
+ _upfirdn2d_cuda_cache = dict()
216
+
217
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
218
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
219
+ """
220
+ # Parse arguments.
221
+ upx, upy = _parse_scaling(up)
222
+ downx, downy = _parse_scaling(down)
223
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
224
+
225
+ # Lookup from cache.
226
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
227
+ if key in _upfirdn2d_cuda_cache:
228
+ return _upfirdn2d_cuda_cache[key]
229
+
230
+ # Forward op.
231
+ class Upfirdn2dCuda(torch.autograd.Function):
232
+ @staticmethod
233
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
234
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
235
+ if f is None:
236
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
237
+ if f.ndim == 1 and f.shape[0] == 1:
238
+ f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1.
239
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
240
+ y = x
241
+ if f.ndim == 2:
242
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
243
+ else:
244
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0)
245
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain)
246
+ ctx.save_for_backward(f)
247
+ ctx.x_shape = x.shape
248
+ return y
249
+
250
+ @staticmethod
251
+ def backward(ctx, dy): # pylint: disable=arguments-differ
252
+ f, = ctx.saved_tensors
253
+ _, _, ih, iw = ctx.x_shape
254
+ _, _, oh, ow = dy.shape
255
+ fw, fh = _get_filter_size(f)
256
+ p = [
257
+ fw - padx0 - 1,
258
+ iw * upx - ow * downx + padx0 - upx + 1,
259
+ fh - pady0 - 1,
260
+ ih * upy - oh * downy + pady0 - upy + 1,
261
+ ]
262
+ dx = None
263
+ df = None
264
+
265
+ if ctx.needs_input_grad[0]:
266
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
267
+
268
+ assert not ctx.needs_input_grad[1]
269
+ return dx, df
270
+
271
+ # Add to cache.
272
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
273
+ return Upfirdn2dCuda
274
+
275
+ #----------------------------------------------------------------------------
276
+
277
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
278
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
279
+
280
+ By default, the result is padded so that its shape matches the input.
281
+ User-specified padding is applied on top of that, with negative values
282
+ indicating cropping. Pixels outside the image are assumed to be zero.
283
+
284
+ Args:
285
+ x: Float32/float64/float16 input tensor of the shape
286
+ `[batch_size, num_channels, in_height, in_width]`.
287
+ f: Float32 FIR filter of the shape
288
+ `[filter_height, filter_width]` (non-separable),
289
+ `[filter_taps]` (separable), or
290
+ `None` (identity).
291
+ padding: Padding with respect to the output. Can be a single number or a
292
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
293
+ (default: 0).
294
+ flip_filter: False = convolution, True = correlation (default: False).
295
+ gain: Overall scaling factor for signal magnitude (default: 1).
296
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
297
+
298
+ Returns:
299
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
300
+ """
301
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
302
+ fw, fh = _get_filter_size(f)
303
+ p = [
304
+ padx0 + fw // 2,
305
+ padx1 + (fw - 1) // 2,
306
+ pady0 + fh // 2,
307
+ pady1 + (fh - 1) // 2,
308
+ ]
309
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
310
+
311
+ #----------------------------------------------------------------------------
312
+
313
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
314
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
315
+
316
+ By default, the result is padded so that its shape is a multiple of the input.
317
+ User-specified padding is applied on top of that, with negative values
318
+ indicating cropping. Pixels outside the image are assumed to be zero.
319
+
320
+ Args:
321
+ x: Float32/float64/float16 input tensor of the shape
322
+ `[batch_size, num_channels, in_height, in_width]`.
323
+ f: Float32 FIR filter of the shape
324
+ `[filter_height, filter_width]` (non-separable),
325
+ `[filter_taps]` (separable), or
326
+ `None` (identity).
327
+ up: Integer upsampling factor. Can be a single int or a list/tuple
328
+ `[x, y]` (default: 1).
329
+ padding: Padding with respect to the output. Can be a single number or a
330
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
331
+ (default: 0).
332
+ flip_filter: False = convolution, True = correlation (default: False).
333
+ gain: Overall scaling factor for signal magnitude (default: 1).
334
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
335
+
336
+ Returns:
337
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
338
+ """
339
+ upx, upy = _parse_scaling(up)
340
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
341
+ fw, fh = _get_filter_size(f)
342
+ p = [
343
+ padx0 + (fw + upx - 1) // 2,
344
+ padx1 + (fw - upx) // 2,
345
+ pady0 + (fh + upy - 1) // 2,
346
+ pady1 + (fh - upy) // 2,
347
+ ]
348
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
349
+
350
+ #----------------------------------------------------------------------------
351
+
352
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
353
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
354
+
355
+ By default, the result is padded so that its shape is a fraction of the input.
356
+ User-specified padding is applied on top of that, with negative values
357
+ indicating cropping. Pixels outside the image are assumed to be zero.
358
+
359
+ Args:
360
+ x: Float32/float64/float16 input tensor of the shape
361
+ `[batch_size, num_channels, in_height, in_width]`.
362
+ f: Float32 FIR filter of the shape
363
+ `[filter_height, filter_width]` (non-separable),
364
+ `[filter_taps]` (separable), or
365
+ `None` (identity).
366
+ down: Integer downsampling factor. Can be a single int or a list/tuple
367
+ `[x, y]` (default: 1).
368
+ padding: Padding with respect to the input. Can be a single number or a
369
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
370
+ (default: 0).
371
+ flip_filter: False = convolution, True = correlation (default: False).
372
+ gain: Overall scaling factor for signal magnitude (default: 1).
373
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
374
+
375
+ Returns:
376
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
377
+ """
378
+ downx, downy = _parse_scaling(down)
379
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
380
+ fw, fh = _get_filter_size(f)
381
+ p = [
382
+ padx0 + (fw - downx + 1) // 2,
383
+ padx1 + (fw - downx) // 2,
384
+ pady0 + (fh - downy + 1) // 2,
385
+ pady1 + (fh - downy) // 2,
386
+ ]
387
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
388
+
389
+ #----------------------------------------------------------------------------
torch_utils/persistence.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Facilities for pickling Python code alongside other data.
10
+
11
+ The pickled code is automatically imported into a separate Python module
12
+ during unpickling. This way, any previously exported pickles will remain
13
+ usable even if the original code is no longer available, or if the current
14
+ version of the code is not consistent with what was originally pickled."""
15
+
16
+ import sys
17
+ import pickle
18
+ import io
19
+ import inspect
20
+ import copy
21
+ import uuid
22
+ import types
23
+ import dnnlib
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ _version = 6 # internal version number
28
+ _decorators = set() # {decorator_class, ...}
29
+ _import_hooks = [] # [hook_function, ...]
30
+ _module_to_src_dict = dict() # {module: src, ...}
31
+ _src_to_module_dict = dict() # {src: module, ...}
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def persistent_class(orig_class):
36
+ r"""Class decorator that extends a given class to save its source code
37
+ when pickled.
38
+
39
+ Example:
40
+
41
+ from torch_utils import persistence
42
+
43
+ @persistence.persistent_class
44
+ class MyNetwork(torch.nn.Module):
45
+ def __init__(self, num_inputs, num_outputs):
46
+ super().__init__()
47
+ self.fc = MyLayer(num_inputs, num_outputs)
48
+ ...
49
+
50
+ @persistence.persistent_class
51
+ class MyLayer(torch.nn.Module):
52
+ ...
53
+
54
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55
+ source code alongside other internal state (e.g., parameters, buffers,
56
+ and submodules). This way, any previously exported pickle will remain
57
+ usable even if the class definitions have been modified or are no
58
+ longer available.
59
+
60
+ The decorator saves the source code of the entire Python module
61
+ containing the decorated class. It does *not* save the source code of
62
+ any imported modules. Thus, the imported modules must be available
63
+ during unpickling, also including `torch_utils.persistence` itself.
64
+
65
+ It is ok to call functions defined in the same module from the
66
+ decorated class. However, if the decorated class depends on other
67
+ classes defined in the same module, they must be decorated as well.
68
+ This is illustrated in the above example in the case of `MyLayer`.
69
+
70
+ It is also possible to employ the decorator just-in-time before
71
+ calling the constructor. For example:
72
+
73
+ cls = MyLayer
74
+ if want_to_make_it_persistent:
75
+ cls = persistence.persistent_class(cls)
76
+ layer = cls(num_inputs, num_outputs)
77
+
78
+ As an additional feature, the decorator also keeps track of the
79
+ arguments that were used to construct each instance of the decorated
80
+ class. The arguments can be queried via `obj.init_args` and
81
+ `obj.init_kwargs`, and they are automatically pickled alongside other
82
+ object state. A typical use case is to first unpickle a previous
83
+ instance of a persistent class, and then upgrade it to use the latest
84
+ version of the source code:
85
+
86
+ with open('old_pickle.pkl', 'rb') as f:
87
+ old_net = pickle.load(f)
88
+ new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
89
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90
+ """
91
+ assert isinstance(orig_class, type)
92
+ if is_persistent(orig_class):
93
+ return orig_class
94
+
95
+ assert orig_class.__module__ in sys.modules
96
+ orig_module = sys.modules[orig_class.__module__]
97
+ orig_module_src = _module_to_src(orig_module)
98
+
99
+ class Decorator(orig_class):
100
+ _orig_module_src = orig_module_src
101
+ _orig_class_name = orig_class.__name__
102
+
103
+ def __init__(self, *args, **kwargs):
104
+ super().__init__(*args, **kwargs)
105
+ self._init_args = copy.deepcopy(args)
106
+ self._init_kwargs = copy.deepcopy(kwargs)
107
+ assert orig_class.__name__ in orig_module.__dict__
108
+ _check_pickleable(self.__reduce__())
109
+
110
+ @property
111
+ def init_args(self):
112
+ return copy.deepcopy(self._init_args)
113
+
114
+ @property
115
+ def init_kwargs(self):
116
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117
+
118
+ def __reduce__(self):
119
+ fields = list(super().__reduce__())
120
+ fields += [None] * max(3 - len(fields), 0)
121
+ if fields[0] is not _reconstruct_persistent_obj:
122
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
124
+ fields[1] = (meta,) # reconstruct args
125
+ fields[2] = None # state dict
126
+ return tuple(fields)
127
+
128
+ Decorator.__name__ = orig_class.__name__
129
+ _decorators.add(Decorator)
130
+ return Decorator
131
+
132
+ #----------------------------------------------------------------------------
133
+
134
+ def is_persistent(obj):
135
+ r"""Test whether the given object or class is persistent, i.e.,
136
+ whether it will save its source code when pickled.
137
+ """
138
+ try:
139
+ if obj in _decorators:
140
+ return True
141
+ except TypeError:
142
+ pass
143
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144
+
145
+ #----------------------------------------------------------------------------
146
+
147
+ def import_hook(hook):
148
+ r"""Register an import hook that is called whenever a persistent object
149
+ is being unpickled. A typical use case is to patch the pickled source
150
+ code to avoid errors and inconsistencies when the API of some imported
151
+ module has changed.
152
+
153
+ The hook should have the following signature:
154
+
155
+ hook(meta) -> modified meta
156
+
157
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158
+
159
+ type: Type of the persistent object, e.g. `'class'`.
160
+ version: Internal version number of `torch_utils.persistence`.
161
+ module_src Original source code of the Python module.
162
+ class_name: Class name in the original Python module.
163
+ state: Internal state of the object.
164
+
165
+ Example:
166
+
167
+ @persistence.import_hook
168
+ def wreck_my_network(meta):
169
+ if meta.class_name == 'MyNetwork':
170
+ print('MyNetwork is being imported. I will wreck it!')
171
+ meta.module_src = meta.module_src.replace("True", "False")
172
+ return meta
173
+ """
174
+ assert callable(hook)
175
+ _import_hooks.append(hook)
176
+
177
+ #----------------------------------------------------------------------------
178
+
179
+ def _reconstruct_persistent_obj(meta):
180
+ r"""Hook that is called internally by the `pickle` module to unpickle
181
+ a persistent object.
182
+ """
183
+ meta = dnnlib.EasyDict(meta)
184
+ meta.state = dnnlib.EasyDict(meta.state)
185
+ for hook in _import_hooks:
186
+ meta = hook(meta)
187
+ assert meta is not None
188
+
189
+ assert meta.version == _version
190
+ module = _src_to_module(meta.module_src)
191
+
192
+ assert meta.type == 'class'
193
+ orig_class = module.__dict__[meta.class_name]
194
+ decorator_class = persistent_class(orig_class)
195
+ obj = decorator_class.__new__(decorator_class)
196
+
197
+ setstate = getattr(obj, '__setstate__', None)
198
+ if callable(setstate):
199
+ setstate(meta.state) # pylint: disable=not-callable
200
+ else:
201
+ obj.__dict__.update(meta.state)
202
+ return obj
203
+
204
+ #----------------------------------------------------------------------------
205
+
206
+ def _module_to_src(module):
207
+ r"""Query the source code of a given Python module.
208
+ """
209
+ src = _module_to_src_dict.get(module, None)
210
+ if src is None:
211
+ src = inspect.getsource(module)
212
+ _module_to_src_dict[module] = src
213
+ _src_to_module_dict[src] = module
214
+ return src
215
+
216
+ def _src_to_module(src):
217
+ r"""Get or create a Python module for the given source code.
218
+ """
219
+ module = _src_to_module_dict.get(src, None)
220
+ if module is None:
221
+ module_name = "_imported_module_" + uuid.uuid4().hex
222
+ module = types.ModuleType(module_name)
223
+ sys.modules[module_name] = module
224
+ _module_to_src_dict[module] = src
225
+ _src_to_module_dict[src] = module
226
+ exec(src, module.__dict__) # pylint: disable=exec-used
227
+ return module
228
+
229
+ #----------------------------------------------------------------------------
230
+
231
+ def _check_pickleable(obj):
232
+ r"""Check that the given object is pickleable, raising an exception if
233
+ it is not. This function is expected to be considerably more efficient
234
+ than actually pickling the object.
235
+ """
236
+ def recurse(obj):
237
+ if isinstance(obj, (list, tuple, set)):
238
+ return [recurse(x) for x in obj]
239
+ if isinstance(obj, dict):
240
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
241
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242
+ return None # Python primitive types are pickleable.
243
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
244
+ return None # NumPy arrays and PyTorch tensors are pickleable.
245
+ if is_persistent(obj):
246
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
247
+ return obj
248
+ with io.BytesIO() as f:
249
+ pickle.dump(recurse(obj), f)
250
+
251
+ #----------------------------------------------------------------------------