akhaliq HF staff commited on
Commit
8ca3a29
·
1 Parent(s): 68225a3
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. compute_direction.py +96 -0
  2. compute_jacobian.py +200 -0
  3. coordinate.py +142 -0
  4. directions/.DS_Store +0 -0
  5. directions/afhq/.DS_Store +0 -0
  6. directions/afhq/stylegan3/eyes-r.npy +3 -0
  7. directions/ffhq/stylegan2/eyebrows.npy +3 -0
  8. directions/ffhq/stylegan2/eyesize.npy +3 -0
  9. directions/ffhq/stylegan2/gaze_direction.npy +3 -0
  10. directions/ffhq/stylegan2/lipstick.npy +3 -0
  11. directions/ffhq/stylegan2/mouth.npy +3 -0
  12. directions/ffhq/stylegan2/nose_length.npy +3 -0
  13. directions/ffhq/stylegan3/eyes-r.npy +3 -0
  14. manipulate.py +253 -0
  15. models/__init__.py +45 -0
  16. models/ghfeat_encoder.py +563 -0
  17. models/inception_model.py +562 -0
  18. models/perceptual_model.py +519 -0
  19. models/pggan_discriminator.py +465 -0
  20. models/pggan_generator.py +401 -0
  21. models/stylegan2_discriminator.py +729 -0
  22. models/stylegan2_generator.py +1394 -0
  23. models/stylegan3_generator.py +1332 -0
  24. models/stylegan_discriminator.py +624 -0
  25. models/stylegan_generator.py +999 -0
  26. models/test.py +146 -0
  27. models/utils/__init__.py +0 -0
  28. models/utils/ops.py +18 -0
  29. requirements/convert.txt +11 -0
  30. requirements/develop.txt +3 -0
  31. requirements/minimal.txt +21 -0
  32. synthesis.py +178 -0
  33. third_party/__init__.py +0 -0
  34. third_party/stylegan2_official_ops/README.md +28 -0
  35. third_party/stylegan2_official_ops/__init__.py +0 -0
  36. third_party/stylegan2_official_ops/bias_act.cpp +99 -0
  37. third_party/stylegan2_official_ops/bias_act.cu +173 -0
  38. third_party/stylegan2_official_ops/bias_act.h +38 -0
  39. third_party/stylegan2_official_ops/bias_act.py +227 -0
  40. third_party/stylegan2_official_ops/conv2d_gradfix.py +189 -0
  41. third_party/stylegan2_official_ops/conv2d_resample.py +168 -0
  42. third_party/stylegan2_official_ops/custom_ops.py +159 -0
  43. third_party/stylegan2_official_ops/fma.py +73 -0
  44. third_party/stylegan2_official_ops/grid_sample_gradfix.py +98 -0
  45. third_party/stylegan2_official_ops/misc.py +281 -0
  46. third_party/stylegan2_official_ops/upfirdn2d.cpp +103 -0
  47. third_party/stylegan2_official_ops/upfirdn2d.cu +350 -0
  48. third_party/stylegan2_official_ops/upfirdn2d.h +59 -0
  49. third_party/stylegan2_official_ops/upfirdn2d.py +401 -0
  50. third_party/stylegan3_official_ops/README.md +30 -0
compute_direction.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Computes the semantic directions regarding a specific image region."""
3
+
4
+ import os
5
+ import argparse
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from coordinate import COORDINATES
10
+ from coordinate import get_mask
11
+ from utils.image_utils import save_image
12
+
13
+
14
+ def parse_args():
15
+ """Parses arguments."""
16
+
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument('jaco_path', type=str,
19
+ help='Path to jacobian matrix.')
20
+ parser.add_argument('--region', type=str, default='eyes',
21
+ help='The region to be used to compute jacobian.')
22
+ parser.add_argument('--save_dir', type=str, default='',
23
+ help='Directory to save the results. If not specified,'
24
+ 'the results will be saved to '
25
+ '`work_dirs/{TASK_SPECIFIC}/` by default')
26
+ parser.add_argument('--job', type=str, default='directions',
27
+ help='Name for the job (default: directions)')
28
+ parser.add_argument('--name', type=str, default='resefa',
29
+ help='Name of help save the results.')
30
+ parser.add_argument('--data_name', type=str, default='ffhq',
31
+ help='Name of the dataset.')
32
+ parser.add_argument('--full_rank', action='store_true',
33
+ help='Whether or not to full rank background'
34
+ ' (default: False).')
35
+ parser.add_argument('--tao', type=float, default=1e-3,
36
+ help='Coefficient to the identity matrix '
37
+ '(default: 1e-3).')
38
+ return parser.parse_args()
39
+
40
+
41
+ def main():
42
+ """Main function."""
43
+ args = parse_args()
44
+ assert os.path.exists(args.jaco_path)
45
+ Jacobians = np.load(args.jaco_path)
46
+ image_size = Jacobians.shape[2]
47
+ w_dim = Jacobians.shape[-1]
48
+ coord_dict = COORDINATES[args.data_name]
49
+ assert args.region in coord_dict, \
50
+ f'{args.region} coordinate is not defined in ' \
51
+ f'COORDINATE_{args.data_name}. Please define this region first!'
52
+ coords = coord_dict[args.region]
53
+ mask = get_mask(image_size, coordinate=coords)
54
+ foreground_ind = np.where(mask == 1)
55
+ background_ind = np.where((1 - mask) == 1)
56
+ temp_dir = f'./work_dirs/{args.job}/{args.data_name}/{args.region}'
57
+ save_dir = args.save_dir or temp_dir
58
+ os.makedirs(save_dir, exist_ok=True)
59
+ for ind in tqdm(range(Jacobians.shape[0])):
60
+ Jacobian = Jacobians[ind]
61
+ if len(Jacobian.shape) == 4: # [H, W, 1, latent_dim]
62
+ Jaco_fore = Jacobian[foreground_ind[0], foreground_ind[1], 0]
63
+ Jaco_back = Jacobian[background_ind[0], background_ind[1], 0]
64
+ elif len(Jacobian.shape) == 5: # [channel, H, W, 1, latent_dim]
65
+ Jaco_fore = Jacobian[:, foreground_ind[0], foreground_ind[1], 0]
66
+ Jaco_back = Jacobian[:, background_ind[0], background_ind[1], 0]
67
+ else:
68
+ raise ValueError('Shape of the Jacobian is not correct!')
69
+ Jaco_fore = np.reshape(Jaco_fore, [-1, w_dim])
70
+ Jaco_back = np.reshape(Jaco_back, [-1, w_dim])
71
+ coef_f = 1 / Jaco_fore.shape[0]
72
+ coef_b = 1 / Jaco_back.shape[0]
73
+ M_fore = coef_f * Jaco_fore.T.dot(Jaco_fore)
74
+ M_back = coef_b * Jaco_back.T.dot(Jaco_back)
75
+ if args.full_rank:
76
+ # J = J_b^TJ_b
77
+ # J = (J + tao * trace(J) * I)
78
+ print('Using full rank')
79
+ coef = args.tao * np.trace(M_back)
80
+ M_back = M_back + coef * np.identity(M_back.shape[0])
81
+ # inv(B) * A = lambda x
82
+ temp = np.linalg.inv(M_back).dot(M_fore)
83
+ eig_val, eig_vec = np.linalg.eig(temp)
84
+ eig_val = np.real(eig_val)
85
+ eig_vec = np.real(eig_vec)
86
+ directions = eig_vec.T
87
+ directions = directions[np.argsort(-eig_val)]
88
+ save_name = f'{save_dir}/image_{ind:02d}_region_{args.region}' \
89
+ f'_name_{args.name}'
90
+ np.save(f'{save_name}.npy', directions)
91
+ mask_i = np.tile(mask[:, :, np.newaxis], [1, 1, 3]) * 255
92
+ save_image(f'{save_name}_mask.png', mask_i.astype(np.uint8))
93
+
94
+
95
+ if __name__ == '__main__':
96
+ main()
compute_jacobian.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Functions to compute Jacobian based on pre-trained GAN generator.
3
+
4
+ Support StyleGAN2 or StyleGAN3
5
+ """
6
+
7
+ import os
8
+ import argparse
9
+ import warnings
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch.autograd.functional import jacobian
16
+ from models import build_model
17
+ from utils.image_utils import save_image
18
+ from utils.image_utils import postprocess_image
19
+ from utils.custom_utils import to_numpy
20
+
21
+
22
+ warnings.filterwarnings(action='ignore', category=UserWarning)
23
+
24
+
25
+ def parse_args():
26
+ """Parses arguments."""
27
+ parser = argparse.ArgumentParser()
28
+ group = parser.add_argument_group('General options.')
29
+ group.add_argument('weight_path', type=str,
30
+ help='Weight path to the pre-trained model.')
31
+ group.add_argument('--save_dir', type=str, default=None,
32
+ help='Directory to save the results. If not specified, '
33
+ 'the results will be saved to '
34
+ '`work_dirs/{TASK_SPECIFIC}/` by default.')
35
+ group.add_argument('--job', type=str, default='jacobians',
36
+ help='Name for the job (default: jacobians)')
37
+ group.add_argument('--seed', type=int, default=4,
38
+ help='Seed for sampling. (default: 4)')
39
+ group.add_argument('--nums', type=int, default=5,
40
+ help='Number of samples to synthesized. (default: 5)')
41
+ group.add_argument('--img_size', type=int, default=1024,
42
+ help='Size of the synthesized images. (default: 1024)')
43
+ group.add_argument('--w_dim', type=int, default=512,
44
+ help='Dimension of the latent w. (default: 512)')
45
+ group.add_argument('--save_jpg', action='store_false',
46
+ help='Whether to save the images used to compute '
47
+ 'jacobians. (default: True)')
48
+ group.add_argument('-d', '--data_name', type=str, default='ffhq',
49
+ help='Name of the datasets. (default: ffhq)')
50
+ group.add_argument('--latent_path', type=str, default='',
51
+ help='Path to the given latent codes. (default: None)')
52
+
53
+ group = parser.add_argument_group('StyleGAN2')
54
+ group.add_argument('--stylegan2', action='store_true',
55
+ help='Whether or not using StyleGAN2. (default: False)')
56
+ group.add_argument('--scale_stylegan2', type=float, default=1.0,
57
+ help='Scale for the number of channel fro stylegan2.')
58
+ group.add_argument('--randomize_noise', type=str, default='const',
59
+ help='Noise type when computing. (const or random)')
60
+
61
+ group = parser.add_argument_group('StyleGAN3')
62
+ group.add_argument('--stylegan3', action='store_true',
63
+ help='Whether or not using StyleGAN3. (default: False)')
64
+ group.add_argument('--cfg', type=str, default='T',
65
+ help='Config of the stylegan3 (T/R).')
66
+ group.add_argument('--scale_stylegan3r', type=float, default=2.0,
67
+ help='Scale for the number of channel for stylegan3 R.')
68
+ group.add_argument('--scale_stylegan3t', type=float, default=1.0,
69
+ help='Scale for the number of channel for stylegan3 T.')
70
+ group.add_argument('--tx', type=float, default=0,
71
+ help='Translate X-coordinate. (default: 0.0)')
72
+ group.add_argument('--ty', type=float, default=0,
73
+ help='Translate Y-coordinate. (default: 0.0)')
74
+ group.add_argument('--rotate', type=float, default=0,
75
+ help='Rotation angle in degrees. (default: 0)')
76
+
77
+ group = parser.add_argument_group('Jacobians')
78
+ group.add_argument('--b', type=float, default=1e-3,
79
+ help='Constant when computing jacobians fast.')
80
+ group.add_argument('--batch_size', type=int, default=4,
81
+ help='Batch size. (default: 4)')
82
+ return parser.parse_args()
83
+
84
+
85
+ def main():
86
+ """Main function."""
87
+ args = parse_args()
88
+ # Parse model configuration.
89
+ assert (args.stylegan2 and not args.stylegan3) or \
90
+ (not args.stylegan2 and args.stylegan3)
91
+ job_disc = ''
92
+ if args.stylegan2:
93
+ config = dict(model_type='StyleGAN2Generator',
94
+ resolution=args.img_size,
95
+ w_dim=args.w_dim,
96
+ fmaps_base=int(args.scale_stylegan2 * (32 << 10)),
97
+ fmaps_max=512,)
98
+ job_disc += 'stylegan2'
99
+ else:
100
+ if args.stylegan3 and args.cfg == 'R':
101
+ config = dict(model_type='StyleGAN3Generator',
102
+ resolution=args.img_size,
103
+ w_dim=args.w_dim,
104
+ fmaps_base=int(args.scale_stylegan3r * (32 << 10)),
105
+ fmaps_max=1024,
106
+ use_radial_filter=True,)
107
+ job_disc += 'stylegan3r'
108
+ elif args.stylegan3 and args.cfg == 'T':
109
+ config = dict(model_type='StyleGAN3Generator',
110
+ resolution=args.img_size,
111
+ w_dim=args.w_dim,
112
+ fmaps_base=int(args.scale_stylegan3t * (32 << 10)),
113
+ fmaps_max=512,
114
+ use_radial_filter=False,
115
+ kernel_size=3,)
116
+ job_disc += 'stylegan3t'
117
+ else:
118
+ raise TypeError(f'StyleGAN3 config type error, need `R/T`,'
119
+ f' but got {args.cfg}')
120
+ job_name = f'seed_{args.seed}_num_{args.nums}_{job_disc}'
121
+ temp_dir = f'work_dirs/{args.job}/{args.data_name}/{job_name}'
122
+ save_dir = args.save_dir or temp_dir
123
+ os.makedirs(save_dir, exist_ok=True)
124
+ if args.save_jpg:
125
+ os.makedirs(f'{save_dir}/images', exist_ok=True)
126
+
127
+ print('Building generator...')
128
+ generator = build_model(**config)
129
+ checkpoint_path = args.weight_path
130
+ print(f'Loading checkpoint from `{checkpoint_path}` ...')
131
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
132
+ if 'generator_smooth' in checkpoint:
133
+ generator.load_state_dict(checkpoint['generator_smooth'])
134
+ else:
135
+ generator.load_state_dict(checkpoint['generator'])
136
+ generator = generator.eval().cuda()
137
+ print('Finish loading checkpoint.')
138
+
139
+ # Set random seed.
140
+ np.random.seed(args.seed)
141
+ torch.manual_seed(args.seed)
142
+ if os.path.exists(args.latent_path):
143
+ latent_zs = np.load(args.latent_path)
144
+ latent_zs = latent_zs[:args.nums]
145
+ else:
146
+ latent_zs = np.random.randn(args.nums, generator.z_dim)
147
+ latent_zs = torch.from_numpy(latent_zs.astype(np.float32))
148
+ latent_zs = latent_zs.cuda()
149
+ with torch.no_grad():
150
+ latent_ws = generator.mapping(latent_zs)['w']
151
+ print(f'Shape of the latent w: {latent_ws.shape}')
152
+
153
+ def syn2jaco(w):
154
+ """Wrap the synthesized function to compute the Jacobian easily.
155
+
156
+ Basically, this function defines a generator that takes the input
157
+ from the W space and then synthesizes an image. If the image is
158
+ larger than 256, it will be resized to 256 to save the time and
159
+ storage.
160
+
161
+ Args:
162
+ w: latent code from the W space
163
+
164
+ Returns:
165
+ An image with the size of [1, 256, 256]
166
+ """
167
+ wp = w.unsqueeze(1).repeat((1, generator.num_layers, 1))
168
+ image = generator.synthesis(wp)['image']
169
+ if image.shape[-1] > 256:
170
+ scale = 256 / image.shape[-1]
171
+ image = F.interpolate(image, scale_factor=scale)
172
+ image = torch.sum(image, dim=1)
173
+ return image
174
+
175
+ jacobians = []
176
+ for idx in tqdm(range(latent_zs.shape[0])):
177
+ latent_w = latent_ws[idx:idx+1]
178
+ jac_i = jacobian(func=syn2jaco,
179
+ inputs=latent_w,
180
+ create_graph=False,
181
+ strict=False)
182
+ jacobians.append(jac_i)
183
+ if args.save_jpg:
184
+ wp = latent_w.unsqueeze(1).repeat((1, generator.num_layers, 1))
185
+ syn_outputs = generator.synthesis(wp)['image']
186
+ syn_outputs = to_numpy(syn_outputs)
187
+ images = postprocess_image(syn_outputs)
188
+ save_path = f'{save_dir}/images/{idx:06d}.jpg'
189
+ save_image(save_path, images[0])
190
+ jacobians = torch.cat(jacobians, dim=0)
191
+ jacobians = to_numpy(jacobians)
192
+ print(f'shape of the jacobian: {jacobians.shape}')
193
+ latent_ws = to_numpy(latent_ws)
194
+ np.save(f'{save_dir}/latent_codes.npy', latent_ws)
195
+ np.save(f'{save_dir}/jacobians_w.npy', jacobians)
196
+ print(f'Finish computing {args.nums} jacobians.')
197
+
198
+
199
+ if __name__ == '__main__':
200
+ main()
coordinate.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Utility functions to help define the region coordinates within an image."""
3
+
4
+ import os
5
+ from glob import glob
6
+ import argparse
7
+ import numpy as np
8
+ import cv2
9
+ from tqdm import tqdm
10
+ from utils.parsing_utils import parse_index
11
+
12
+
13
+ def get_mask_by_coordinates(image_size, coordinate):
14
+ """Get mask using the provided coordinates."""
15
+ mask = np.zeros([image_size, image_size], dtype=np.float32)
16
+ center_x, center_y = coordinate[0], coordinate[1]
17
+ crop_x, crop_y = coordinate[2], coordinate[3]
18
+ xx = center_x - crop_x // 2
19
+ yy = center_y - crop_y // 2
20
+ mask[xx:xx + crop_x, yy:yy + crop_y] = 1.
21
+ return mask
22
+
23
+
24
+ def get_mask_by_segmentation(seg_mask, label):
25
+ """Get the mask using the segmentation array and labels."""
26
+ zeros = np.zeros_like(seg_mask)
27
+ ones = np.ones_like(seg_mask)
28
+ mask = np.where(seg_mask == label, ones, zeros)
29
+ return mask
30
+
31
+
32
+ def get_mask(image_size, coordinate=None, seg_mask=None, labels='1'):
33
+ """Get mask using either the coordinate or the segmentation array."""
34
+ if coordinate is not None:
35
+ print('Using coordinate to get mask!')
36
+ mask = get_mask_by_coordinates(image_size, coordinate)
37
+ else:
38
+ print('Using segmentation to get the mask!')
39
+ print(f'Using label {labels}')
40
+ mask = np.zeros_like(seg_mask)
41
+ for label_ in labels:
42
+ mask += get_mask_by_segmentation(seg_mask, int(label_))
43
+ mask = np.clip(mask, a_min=0, a_max=1)
44
+
45
+ return mask
46
+
47
+
48
+ # For FFHQ [center_x, center_y, height, width]
49
+ # Those coordinates are suitable for both ffhq and metface.
50
+ COORDINATE_ffhq = {'left_eye': [120, 95, 20, 38],
51
+ 'right_eye': [120, 159, 20, 38],
52
+ 'eyes': [120, 128, 20, 115],
53
+ 'nose': [142, 131, 40, 46],
54
+ 'mouth': [184, 127, 30, 70],
55
+ 'chin': [217, 130, 42, 110],
56
+ 'eyebrow': [126, 105, 15, 118],
57
+ }
58
+
59
+
60
+ # For FFHQ unaligned
61
+ COORDINATE_ffhqu = {'eyesr2': [134, 116, 30, 115],
62
+ 'eyesr3': [64, 128, 26, 115],
63
+ 'eyest0': [70, 88, 30, 115],
64
+ 'eyest3': [108, 142, 26, 115],
65
+ }
66
+
67
+ # [center_x, center_y, height, width]
68
+ COORDINATE_biggan = {'center0': [120, 120, 80, 80],
69
+ 'center1': [120, 120, 130, 130],
70
+ 'center2': [120, 120, 200, 200],
71
+ 'left_side': [128, 64, 256, 128],
72
+ 'top_side': [64, 128, 128, 256],
73
+ 'head0': [89, 115, 49, 70],
74
+ 'head1': [93, 110, 48, 70]}
75
+
76
+
77
+ COORDINATES = {'ffhq': COORDINATE_ffhq,
78
+ 'ffhqu': COORDINATE_ffhqu,
79
+ 'biggan': COORDINATE_biggan
80
+ }
81
+
82
+
83
+ def parse_args():
84
+ """Parses arguments."""
85
+
86
+ parser = argparse.ArgumentParser()
87
+ parser.add_argument('--image_path', type=str, default='',
88
+ help='The path to the image.')
89
+ parser.add_argument('--mask_path', type=str, default='',
90
+ help='The path to the mask.')
91
+ parser.add_argument('--save_dir', type=str, default='',
92
+ help='The path to the image.')
93
+ parser.add_argument('--label', type=str, default=None,
94
+ help='The label number in the mask.')
95
+ parser.add_argument('--data', type=str, default='ffhq',
96
+ help='The name of the dataset to test.')
97
+ parser.add_argument('--num', type=int, default=0,
98
+ help='number of image to display.')
99
+ parser.add_argument('--img_type', type=str, default='jpeg',
100
+ help='Format of the image.')
101
+
102
+ return parser.parse_args()
103
+
104
+
105
+ def main():
106
+ """Main function to show an image with masks"""
107
+ args = parse_args()
108
+ save_dir = args.save_dir or './temp_mask'
109
+ os.makedirs(save_dir, exist_ok=True)
110
+ images = sorted(glob(f'{args.image_path}/*.{args.img_type}'))[args.num:]
111
+ label_files = sorted(glob(f'{args.mask_path}/*.npy'))[args.num:]
112
+ COORDINATE = COORDINATES[args.data]
113
+ for i, image in tqdm(enumerate(images)):
114
+ img = cv2.imread(image)
115
+ im_name = image.split('/')[-1].split('.')[0]
116
+ if args.label is None:
117
+ for name, coord in COORDINATE.items():
118
+ if len(coord) == 0:
119
+ continue
120
+ mask = np.zeros(img.shape, dtype=np.float32)
121
+ center_x, center_y = coord[0], coord[1]
122
+ crop_x, crop_y = coord[2], coord[3]
123
+ xx = center_x - crop_x // 2
124
+ yy = center_y - crop_y // 2
125
+ mask[xx:xx + crop_x, yy:yy + crop_y, :] = 1.
126
+ img_ = img * mask
127
+ cv2.imwrite(f'{save_dir}/{im_name}_{name}.png', img_)
128
+ else:
129
+ print('Using segmentation to get the mask!')
130
+ seg_mask = np.load(label_files[i])
131
+ labels = parse_index(args.label)
132
+ print(f'Using label {labels}')
133
+ mask = np.zeros_like(seg_mask)
134
+ for label_ in labels:
135
+ mask += get_mask_by_segmentation(seg_mask, int(label_))
136
+ mask = np.clip(mask, a_min=0, a_max=1)
137
+ img_ = img * mask[:, :, np.newaxis]
138
+ cv2.imwrite(f'{save_dir}/{im_name}_{args.label}.png', img_)
139
+
140
+
141
+ if __name__ == '__main__':
142
+ main()
directions/.DS_Store ADDED
Binary file (6.15 kB). View file
 
directions/afhq/.DS_Store ADDED
Binary file (6.15 kB). View file
 
directions/afhq/stylegan3/eyes-r.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:959982185cd0401b8ab984ad2c1c22c01a494d8928a765e353c4fc34e9b079a2
3
+ size 2176
directions/ffhq/stylegan2/eyebrows.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8a7385ee951fa93b34044a768254a937327c87b0d86a569b62ae4593ae0b765
3
+ size 2176
directions/ffhq/stylegan2/eyesize.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:74d417489944e2f2c6ee9622c4a3745edf20edad5837f5d4b2d89847003d0ec5
3
+ size 2176
directions/ffhq/stylegan2/gaze_direction.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98c7787363716096706db4f6cc039cf5ffc773142a0f790ed185300ddaf8fc0e
3
+ size 2176
directions/ffhq/stylegan2/lipstick.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ce301bb70f35e10c06b9d6fbb485644a65ebf1ac57a080e70fc3448b3343dc6
3
+ size 2176
directions/ffhq/stylegan2/mouth.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cf4fe4565fd2d3f363d0b2e7de1073f1379c8488113b99f8d3bad9ba53a0fa4
3
+ size 2176
directions/ffhq/stylegan2/nose_length.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a7d3685b38f18e87a49cc3a79106f18dbdf30b1e8d3db5df21c84d592566ea5
3
+ size 2176
directions/ffhq/stylegan3/eyes-r.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:91f7c4432a030568b1aac414c30e7f00cf4d441228a625eb8d971f5f28b036db
3
+ size 2176
manipulate.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Manipulates synthesized or real images with existing boundary.
3
+
4
+ Support StyleGAN2 and StyleGAN3.
5
+ """
6
+
7
+ import os.path
8
+ import argparse
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ import torch
12
+
13
+ from models import build_model
14
+ from utils.visualizers.html_visualizer import HtmlVisualizer
15
+ from utils.image_utils import save_image
16
+ from utils.parsing_utils import parse_index
17
+ from utils.image_utils import postprocess_image
18
+ from utils.custom_utils import to_numpy, linear_interpolate
19
+ from utils.custom_utils import make_transform
20
+
21
+
22
+ def parse_args():
23
+ """Parses arguments."""
24
+ parser = argparse.ArgumentParser()
25
+ group = parser.add_argument_group('General options.')
26
+ group.add_argument('weight_path', type=str,
27
+ help='Weight path to the pre-trained model.')
28
+ group.add_argument('boundary_path', type=str,
29
+ help='Path to the attribute vectors.')
30
+ group.add_argument('--save_dir', type=str, default=None,
31
+ help='Directory to save the results. If not specified, '
32
+ 'the results will be saved to '
33
+ '`work_dirs/{TASK_SPECIFIC}/` by default.')
34
+ group.add_argument('--job', type=str, default='manipulations',
35
+ help='Name for the job. (default: manipulations)')
36
+ group.add_argument('--seed', type=int, default=4,
37
+ help='Seed for sampling. (default: 4)')
38
+ group.add_argument('--nums', type=int, default=10,
39
+ help='Number of samples to synthesized. (default: 10)')
40
+ group.add_argument('--img_size', type=int, default=1024,
41
+ help='Size of the synthesized images. (default: 1024)')
42
+ group.add_argument('--vis_size', type=int, default=256,
43
+ help='Size of the visualize images. (default: 256)')
44
+ group.add_argument('--w_dim', type=int, default=512,
45
+ help='Dimension of the latent w. (default: 512)')
46
+ group.add_argument('--batch_size', type=int, default=4,
47
+ help='Batch size. (default: 4)')
48
+ group.add_argument('--save_jpg', action='store_true', default=False,
49
+ help='Whether to save raw image. (default: False)')
50
+ group.add_argument('-d', '--data_name', type=str, default='ffhq',
51
+ help='Name of the datasets. (default: ffhq)')
52
+ group.add_argument('--latent_path', type=str, default='',
53
+ help='Path to the given latent codes. (default: None)')
54
+ group.add_argument('--trunc_psi', type=float, default=0.7,
55
+ help='Psi factor used for truncation. (default: 0.7)')
56
+ group.add_argument('--trunc_layers', type=int, default=8,
57
+ help='Number of layers to perform truncation.'
58
+ ' (default: 8)')
59
+ group.add_argument('--name', type=str, default='resefa',
60
+ help='Name of help save the results.')
61
+
62
+ group = parser.add_argument_group('StyleGAN2')
63
+ group.add_argument('--stylegan2', action='store_true',
64
+ help='Whether or not using StyleGAN2. (default: False)')
65
+ group.add_argument('--scale_stylegan2', type=float, default=1.0,
66
+ help='Scale for the number of channel fro stylegan2.')
67
+ group.add_argument('--randomize_noise', type=str, default='const',
68
+ help='Noise type when editing. (const or random)')
69
+
70
+ group = parser.add_argument_group('StyleGAN3')
71
+ group.add_argument('--stylegan3', action='store_true',
72
+ help='Whether or not using StyleGAN3. (default: False)')
73
+ group.add_argument('--cfg', type=str, default='T',
74
+ help='Config of the stylegan3 (T/R)')
75
+ group.add_argument('--scale_stylegan3r', type=float, default=2.0,
76
+ help='Scale for the number of channel for stylegan3 R.')
77
+ group.add_argument('--scale_stylegan3t', type=float, default=1.0,
78
+ help='Scale for the number of channel for stylegan3 T.')
79
+ group.add_argument('--tx', type=float, default=0,
80
+ help='Translate X-coordinate. (default: 0.0)')
81
+ group.add_argument('--ty', type=float, default=0,
82
+ help='Translate Y-coordinate. (default: 0.0)')
83
+ group.add_argument('--rotate', type=float, default=0,
84
+ help='Rotation angle in degrees. (default: 0)')
85
+
86
+ group = parser.add_argument_group('Manipulation')
87
+ group.add_argument('--mani_layers', type=str, default='4,5,6,7',
88
+ help='The layers will be manipulated.'
89
+ '(default: 4,5,6,7). For the eyebrow and lipstick,'
90
+ 'using [8-11] layers instead.')
91
+ group.add_argument('--step', type=int, default=7,
92
+ help='Number of manipulation steps. (default: 7)')
93
+ group.add_argument('--start', type=int, default=0,
94
+ help='The start index of the manipulation directions.')
95
+ group.add_argument('--end', type=int, default=1,
96
+ help='The end index of the manipulation directions.')
97
+ group.add_argument('--start_distance', type=float, default=-10.0,
98
+ help='Start distance for manipulation. (default: -10.0)')
99
+ group.add_argument('--end_distance', type=float, default=10.0,
100
+ help='End distance for manipulation. (default: 10.0)')
101
+
102
+ return parser.parse_args()
103
+
104
+
105
+ def main():
106
+ """Main function."""
107
+ args = parse_args()
108
+ # Parse model configuration.
109
+ assert (args.stylegan2 and not args.stylegan3) or \
110
+ (not args.stylegan2 and args.stylegan3)
111
+ checkpoint_path = args.weight_path
112
+ boundary_path = args.boundary_path
113
+ assert os.path.exists(checkpoint_path)
114
+ assert os.path.exists(boundary_path)
115
+ boundary_name = os.path.splitext(os.path.basename(boundary_path))[0]
116
+ job_disc = ''
117
+ if args.stylegan2:
118
+ config = dict(model_type='StyleGAN2Generator',
119
+ resolution=args.img_size,
120
+ w_dim=args.w_dim,
121
+ fmaps_base=int(args.scale_stylegan2 * (32 << 10)),
122
+ fmaps_max=512,)
123
+ job_disc += 'stylegan2'
124
+ else:
125
+ if args.stylegan3 and args.cfg == 'R':
126
+ config = dict(model_type='StyleGAN3Generator',
127
+ resolution=args.img_size,
128
+ w_dim=args.w_dim,
129
+ fmaps_base=int(args.scale_stylegan3r * (32 << 10)),
130
+ fmaps_max=1024,
131
+ use_radial_filter=True,)
132
+ job_disc += 'stylegan3r'
133
+ elif args.stylegan3 and args.cfg == 'T':
134
+ config = dict(model_type='StyleGAN3Generator',
135
+ resolution=args.img_size,
136
+ w_dim=args.w_dim,
137
+ fmaps_base=int(args.scale_stylegan3t * (32 << 10)),
138
+ fmaps_max=512,
139
+ use_radial_filter=False,
140
+ kernel_size=3,)
141
+ job_disc += 'stylegan3t'
142
+ else:
143
+ raise TypeError(f'StyleGAN3 config type error, need `R/T`,'
144
+ f' but got {args.cfg} instead.')
145
+
146
+ # Get work directory and job name.
147
+ save_dir = args.save_dir or f'work_dirs/{args.job}/{args.data_name}'
148
+ os.makedirs(save_dir, exist_ok=True)
149
+ job_name = f'seed_{args.seed}_num_{args.nums}_{job_disc}_{boundary_name}'
150
+ os.makedirs(f'{save_dir}/{job_name}', exist_ok=True)
151
+
152
+ print('Building generator...')
153
+ generator = build_model(**config)
154
+ print(f'Loading checkpoint from `{checkpoint_path}` ...')
155
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
156
+ if 'generator_smooth' in checkpoint:
157
+ generator.load_state_dict(checkpoint['generator_smooth'])
158
+ else:
159
+ generator.load_state_dict(checkpoint['generator'])
160
+ generator = generator.eval().cuda()
161
+ print('Finish loading checkpoint.')
162
+ if args.stylegan3 and hasattr(generator.synthesis, 'early_layer'):
163
+ m = make_transform(args.tx, args.ty, args.rotate)
164
+ m = np.linalg.inv(m)
165
+ generator.synthesis.early_layer.transform.copy_(torch.from_numpy(m))
166
+
167
+ np.random.seed(args.seed)
168
+ torch.manual_seed(args.seed)
169
+ if os.path.exists(args.latent_path):
170
+ print(f'Load latent codes from {args.latent_path}')
171
+ latent_zs = np.load(args.latent_path)
172
+ latent_zs = latent_zs[:args.nums]
173
+ else:
174
+ print('Sampling latent code randomly')
175
+ latent_zs = np.random.randn(args.nums, generator.z_dim)
176
+ latent_zs = torch.from_numpy(latent_zs.astype(np.float32))
177
+ latent_zs = latent_zs.cuda()
178
+ num_images = latent_zs.shape[0]
179
+ wp = []
180
+ for idx in range(0, num_images, args.batch_size):
181
+ latent_z = latent_zs[idx:idx+args.batch_size]
182
+ latent_w_ = generator.mapping(latent_z, None)['wp']
183
+ wp.append(latent_w_)
184
+ wp = torch.cat(wp, dim=0)
185
+ trunc_psi = args.trunc_psi
186
+ trunc_layers = args.trunc_layers
187
+ if trunc_psi < 1.0 and trunc_layers > 0:
188
+ w_avg = generator.w_avg
189
+ w_avg = w_avg.reshape(1, -1, generator.w_dim)[:, :trunc_layers]
190
+ wp[:, :trunc_layers] = w_avg.lerp(wp[:, :trunc_layers], trunc_psi)
191
+ print(f'Shape of the latent ws: {wp.shape}')
192
+ image_list = []
193
+ for i in range(num_images):
194
+ image_list.append(f'{i:06d}')
195
+
196
+ print('Loading boundary.')
197
+ directions = np.load(boundary_path)
198
+ layer_index = parse_index(args.mani_layers)
199
+ if not layer_index:
200
+ layer_index = list(range(generator.num_layers - 1))
201
+ print(f'Manipulating on layers `{layer_index}`.')
202
+
203
+ vis_size = None if args.vis_size == 0 else args.vis_size
204
+ delta_num = args.end - args.start
205
+ visualizer = HtmlVisualizer(num_rows=num_images * delta_num,
206
+ num_cols=args.step + 2,
207
+ image_size=vis_size)
208
+ visualizer.set_headers(
209
+ ['Name', 'Origin'] +
210
+ [f'Step {i:02d}' for i in range(1, args.step + 1)]
211
+ )
212
+ # Manipulate images.
213
+ print('Start manipulation.')
214
+ for row in tqdm(range(num_images)):
215
+ latent_w = wp[row:row+1]
216
+ images_ori = generator.synthesis(latent_w)['image']
217
+ images_ori = postprocess_image(to_numpy(images_ori))
218
+ if args.save_jpg:
219
+ save_image(f'{save_dir}/{job_name}/{row:06d}_orin.jpg',
220
+ images_ori[0])
221
+ for num_direc in range(args.start, args.end):
222
+ html_row = num_direc - args.start
223
+ direction = directions[num_direc:num_direc+1]
224
+ direction = np.tile(direction, [1, generator.num_layers, 1])
225
+ visualizer.set_cell(row * delta_num + html_row, 0,
226
+ text=f'{image_list[row]}_{num_direc:03d}')
227
+ visualizer.set_cell(row * delta_num + html_row, 1,
228
+ image=images_ori[0])
229
+ mani_codes = linear_interpolate(latent_code=to_numpy(latent_w),
230
+ boundary=direction,
231
+ layer_index=layer_index,
232
+ start_distance=args.start_distance,
233
+ end_distance=args.end_distance,
234
+ steps=args.step)
235
+ mani_codes = torch.from_numpy(mani_codes.astype(np.float32)).cuda()
236
+ for idx in range(0, mani_codes.shape[0], args.batch_size):
237
+ codes_ = mani_codes[idx:idx+args.batch_size]
238
+ images_ = generator.synthesis(codes_)['image']
239
+ images_ = postprocess_image(to_numpy(images_))
240
+ for i in range(images_.shape[0]):
241
+ visualizer.set_cell(row * delta_num + html_row, idx+i+2,
242
+ image=images_[i])
243
+ if args.save_jpg:
244
+ save_image(f'{save_dir}/{job_name}/{row:06d}_ind_'
245
+ f'{num_direc:06d}_mani_{idx+i:06d}.jpg',
246
+ images_[i])
247
+ # Save results.
248
+ np.save(f'{save_dir}/{job_name}/latent_codes.npy', to_numpy(wp))
249
+ visualizer.save(f'{save_dir}/{job_name}_{args.name}.html')
250
+
251
+
252
+ if __name__ == '__main__':
253
+ main()
models/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Collects all models."""
3
+
4
+ from .pggan_generator import PGGANGenerator
5
+ from .pggan_discriminator import PGGANDiscriminator
6
+ from .stylegan_generator import StyleGANGenerator
7
+ from .stylegan_discriminator import StyleGANDiscriminator
8
+ from .stylegan2_generator import StyleGAN2Generator
9
+ from .stylegan2_discriminator import StyleGAN2Discriminator
10
+ from .stylegan3_generator import StyleGAN3Generator
11
+ from .ghfeat_encoder import GHFeatEncoder
12
+ from .perceptual_model import PerceptualModel
13
+ from .inception_model import InceptionModel
14
+
15
+ __all__ = ['build_model']
16
+
17
+ _MODELS = {
18
+ 'PGGANGenerator': PGGANGenerator,
19
+ 'PGGANDiscriminator': PGGANDiscriminator,
20
+ 'StyleGANGenerator': StyleGANGenerator,
21
+ 'StyleGANDiscriminator': StyleGANDiscriminator,
22
+ 'StyleGAN2Generator': StyleGAN2Generator,
23
+ 'StyleGAN2Discriminator': StyleGAN2Discriminator,
24
+ 'StyleGAN3Generator': StyleGAN3Generator,
25
+ 'GHFeatEncoder': GHFeatEncoder,
26
+ 'PerceptualModel': PerceptualModel.build_model,
27
+ 'InceptionModel': InceptionModel.build_model
28
+ }
29
+
30
+
31
+ def build_model(model_type, **kwargs):
32
+ """Builds a model based on its class type.
33
+
34
+ Args:
35
+ model_type: Class type to which the model belongs, which is case
36
+ sensitive.
37
+ **kwargs: Additional arguments to build the model.
38
+
39
+ Raises:
40
+ ValueError: If the `model_type` is not supported.
41
+ """
42
+ if model_type not in _MODELS:
43
+ raise ValueError(f'Invalid model type: `{model_type}`!\n'
44
+ f'Types allowed: {list(_MODELS)}.')
45
+ return _MODELS[model_type](**kwargs)
models/ghfeat_encoder.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the implementation of encoder used in GH-Feat (including IDInvert).
3
+
4
+ ResNet is used as the backbone.
5
+
6
+ GH-Feat paper: https://arxiv.org/pdf/2007.10379.pdf
7
+ IDInvert paper: https://arxiv.org/pdf/2004.00049.pdf
8
+
9
+ NOTE: Please use `latent_num` and `num_latents_per_head` to control the
10
+ inversion space, such as Y-space used in GH-Feat and W-space used in IDInvert.
11
+ In addition, IDInvert sets `use_fpn` and `use_sam` as `False` by default.
12
+ """
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ import torch.distributed as dist
20
+
21
+ __all__ = ['GHFeatEncoder']
22
+
23
+ # Resolutions allowed.
24
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
25
+
26
+ # pylint: disable=missing-function-docstring
27
+
28
+ class BasicBlock(nn.Module):
29
+ """Implementation of ResNet BasicBlock."""
30
+
31
+ expansion = 1
32
+
33
+ def __init__(self,
34
+ inplanes,
35
+ planes,
36
+ base_width=64,
37
+ stride=1,
38
+ groups=1,
39
+ dilation=1,
40
+ norm_layer=None,
41
+ downsample=None):
42
+ super().__init__()
43
+ if base_width != 64:
44
+ raise ValueError(f'BasicBlock of ResNet only supports '
45
+ f'`base_width=64`, but {base_width} received!')
46
+ if stride not in [1, 2]:
47
+ raise ValueError(f'BasicBlock of ResNet only supports `stride=1` '
48
+ f'and `stride=2`, but {stride} received!')
49
+ if groups != 1:
50
+ raise ValueError(f'BasicBlock of ResNet only supports `groups=1`, '
51
+ f'but {groups} received!')
52
+ if dilation != 1:
53
+ raise ValueError(f'BasicBlock of ResNet only supports '
54
+ f'`dilation=1`, but {dilation} received!')
55
+ assert self.expansion == 1
56
+
57
+ self.stride = stride
58
+ if norm_layer is None:
59
+ norm_layer = nn.BatchNorm2d
60
+ self.conv1 = nn.Conv2d(in_channels=inplanes,
61
+ out_channels=planes,
62
+ kernel_size=3,
63
+ stride=stride,
64
+ padding=1,
65
+ groups=1,
66
+ dilation=1,
67
+ bias=False)
68
+ self.bn1 = norm_layer(planes)
69
+ self.relu = nn.ReLU(inplace=True)
70
+ self.conv2 = nn.Conv2d(in_channels=planes,
71
+ out_channels=planes,
72
+ kernel_size=3,
73
+ stride=1,
74
+ padding=1,
75
+ groups=1,
76
+ dilation=1,
77
+ bias=False)
78
+ self.bn2 = norm_layer(planes)
79
+ self.downsample = downsample
80
+
81
+ def forward(self, x):
82
+ identity = self.downsample(x) if self.downsample is not None else x
83
+
84
+ out = self.conv1(x)
85
+ out = self.bn1(out)
86
+ out = self.relu(out)
87
+
88
+ out = self.conv2(out)
89
+ out = self.bn2(out)
90
+ out = self.relu(out + identity)
91
+
92
+ return out
93
+
94
+
95
+ class Bottleneck(nn.Module):
96
+ """Implementation of ResNet Bottleneck."""
97
+
98
+ expansion = 4
99
+
100
+ def __init__(self,
101
+ inplanes,
102
+ planes,
103
+ base_width=64,
104
+ stride=1,
105
+ groups=1,
106
+ dilation=1,
107
+ norm_layer=None,
108
+ downsample=None):
109
+ super().__init__()
110
+ if stride not in [1, 2]:
111
+ raise ValueError(f'Bottleneck of ResNet only supports `stride=1` '
112
+ f'and `stride=2`, but {stride} received!')
113
+
114
+ width = int(planes * (base_width / 64)) * groups
115
+ self.stride = stride
116
+ if norm_layer is None:
117
+ norm_layer = nn.BatchNorm2d
118
+ self.conv1 = nn.Conv2d(in_channels=inplanes,
119
+ out_channels=width,
120
+ kernel_size=1,
121
+ stride=1,
122
+ padding=0,
123
+ dilation=1,
124
+ groups=1,
125
+ bias=False)
126
+ self.bn1 = norm_layer(width)
127
+ self.conv2 = nn.Conv2d(in_channels=width,
128
+ out_channels=width,
129
+ kernel_size=3,
130
+ stride=stride,
131
+ padding=dilation,
132
+ groups=groups,
133
+ dilation=dilation,
134
+ bias=False)
135
+ self.bn2 = norm_layer(width)
136
+ self.conv3 = nn.Conv2d(in_channels=width,
137
+ out_channels=planes * self.expansion,
138
+ kernel_size=1,
139
+ stride=1,
140
+ padding=0,
141
+ dilation=1,
142
+ groups=1,
143
+ bias=False)
144
+ self.bn3 = norm_layer(planes * self.expansion)
145
+ self.relu = nn.ReLU(inplace=True)
146
+ self.downsample = downsample
147
+
148
+ def forward(self, x):
149
+ identity = self.downsample(x) if self.downsample is not None else x
150
+
151
+ out = self.conv1(x)
152
+ out = self.bn1(out)
153
+ out = self.relu(out)
154
+
155
+ out = self.conv2(out)
156
+ out = self.bn2(out)
157
+ out = self.relu(out)
158
+
159
+ out = self.conv3(out)
160
+ out = self.bn3(out)
161
+ out = self.relu(out + identity)
162
+
163
+ return out
164
+
165
+
166
+ class GHFeatEncoder(nn.Module):
167
+ """Define the ResNet-based encoder network for GAN inversion.
168
+
169
+ On top of the backbone, there are several task-heads to produce inverted
170
+ codes. Please use `latent_dim` and `num_latents_per_head` to define the
171
+ structure. For example, `latent_dim = [512] * 14` and
172
+ `num_latents_per_head = [4, 4, 6]` can be used for StyleGAN inversion with
173
+ 14-layer latent codes, where 3 task heads (corresponding to 4, 4, 6 layers,
174
+ respectively) are used.
175
+
176
+ Settings for the encoder network:
177
+
178
+ (1) resolution: The resolution of the output image.
179
+ (2) latent_dim: Dimension of the latent space. A number (one code will be
180
+ produced), or a list of numbers regarding layer-wise latent codes.
181
+ (3) num_latents_per_head: Number of latents that is produced by each head.
182
+ (4) image_channels: Number of channels of the output image. (default: 3)
183
+ (5) final_res: Final resolution of the convolutional layers. (default: 4)
184
+
185
+ ResNet-related settings:
186
+
187
+ (1) network_depth: Depth of the network, like 18 for ResNet18. (default: 18)
188
+ (2) inplanes: Number of channels of the first convolutional layer.
189
+ (default: 64)
190
+ (3) groups: Groups of the convolution, used in ResNet. (default: 1)
191
+ (4) width_per_group: Number of channels per group, used in ResNet.
192
+ (default: 64)
193
+ (5) replace_stride_with_dilation: Whether to replace stride with dilation,
194
+ used in ResNet. (default: None)
195
+ (6) norm_layer: Normalization layer used in the encoder. If set as `None`,
196
+ `nn.BatchNorm2d` will be used. Also, please NOTE that when using batch
197
+ normalization, the batch size is required to be larger than one for
198
+ training. (default: nn.BatchNorm2d)
199
+ (7) max_channels: Maximum number of channels in each layer. (default: 512)
200
+
201
+ Task-head related settings:
202
+
203
+ (1) use_fpn: Whether to use Feature Pyramid Network (FPN) before outputting
204
+ the latent code. (default: True)
205
+ (2) fpn_channels: Number of channels used in FPN. (default: 512)
206
+ (3) use_sam: Whether to use Spatial Alignment Module (SAM) before outputting
207
+ the latent code. (default: True)
208
+ (4) sam_channels: Number of channels used in SAM. (default: 512)
209
+ """
210
+
211
+ arch_settings = {
212
+ 18: (BasicBlock, [2, 2, 2, 2]),
213
+ 34: (BasicBlock, [3, 4, 6, 3]),
214
+ 50: (Bottleneck, [3, 4, 6, 3]),
215
+ 101: (Bottleneck, [3, 4, 23, 3]),
216
+ 152: (Bottleneck, [3, 8, 36, 3])
217
+ }
218
+
219
+ def __init__(self,
220
+ resolution,
221
+ latent_dim,
222
+ num_latents_per_head,
223
+ image_channels=3,
224
+ final_res=4,
225
+ network_depth=18,
226
+ inplanes=64,
227
+ groups=1,
228
+ width_per_group=64,
229
+ replace_stride_with_dilation=None,
230
+ norm_layer=nn.BatchNorm2d,
231
+ max_channels=512,
232
+ use_fpn=True,
233
+ fpn_channels=512,
234
+ use_sam=True,
235
+ sam_channels=512):
236
+ super().__init__()
237
+
238
+ if resolution not in _RESOLUTIONS_ALLOWED:
239
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
240
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
241
+ if network_depth not in self.arch_settings:
242
+ raise ValueError(f'Invalid network depth: `{network_depth}`!\n'
243
+ f'Options allowed: '
244
+ f'{list(self.arch_settings.keys())}.')
245
+ if isinstance(latent_dim, int):
246
+ latent_dim = [latent_dim]
247
+ assert isinstance(latent_dim, (list, tuple))
248
+ assert isinstance(num_latents_per_head, (list, tuple))
249
+ assert sum(num_latents_per_head) == len(latent_dim)
250
+
251
+ self.resolution = resolution
252
+ self.latent_dim = latent_dim
253
+ self.num_latents_per_head = num_latents_per_head
254
+ self.num_heads = len(self.num_latents_per_head)
255
+ self.image_channels = image_channels
256
+ self.final_res = final_res
257
+ self.inplanes = inplanes
258
+ self.network_depth = network_depth
259
+ self.groups = groups
260
+ self.dilation = 1
261
+ self.base_width = width_per_group
262
+ self.replace_stride_with_dilation = replace_stride_with_dilation
263
+ if norm_layer is None:
264
+ norm_layer = nn.BatchNorm2d
265
+ if norm_layer == nn.BatchNorm2d and dist.is_initialized():
266
+ norm_layer = nn.SyncBatchNorm
267
+ self.norm_layer = norm_layer
268
+ self.max_channels = max_channels
269
+ self.use_fpn = use_fpn
270
+ self.fpn_channels = fpn_channels
271
+ self.use_sam = use_sam
272
+ self.sam_channels = sam_channels
273
+
274
+ block_fn, num_blocks_per_stage = self.arch_settings[network_depth]
275
+
276
+ self.num_stages = int(np.log2(resolution // final_res)) - 1
277
+ # Add one block for additional stages.
278
+ for i in range(len(num_blocks_per_stage), self.num_stages):
279
+ num_blocks_per_stage.append(1)
280
+ if replace_stride_with_dilation is None:
281
+ replace_stride_with_dilation = [False] * self.num_stages
282
+
283
+ # Backbone.
284
+ self.conv1 = nn.Conv2d(in_channels=self.image_channels,
285
+ out_channels=self.inplanes,
286
+ kernel_size=7,
287
+ stride=2,
288
+ padding=3,
289
+ bias=False)
290
+ self.bn1 = norm_layer(self.inplanes)
291
+ self.relu = nn.ReLU(inplace=True)
292
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
293
+
294
+ self.stage_channels = [self.inplanes]
295
+ self.stages = nn.ModuleList()
296
+ for i in range(self.num_stages):
297
+ inplanes = self.inplanes if i == 0 else planes * block_fn.expansion
298
+ planes = min(self.max_channels, self.inplanes * (2 ** i))
299
+ num_blocks = num_blocks_per_stage[i]
300
+ stride = 1 if i == 0 else 2
301
+ dilate = replace_stride_with_dilation[i]
302
+ self.stages.append(self._make_stage(block_fn=block_fn,
303
+ inplanes=inplanes,
304
+ planes=planes,
305
+ num_blocks=num_blocks,
306
+ stride=stride,
307
+ dilate=dilate))
308
+ self.stage_channels.append(planes * block_fn.expansion)
309
+
310
+ if self.num_heads > len(self.stage_channels):
311
+ raise ValueError('Number of task heads is larger than number of '
312
+ 'stages! Please reduce the number of heads.')
313
+
314
+ # Task-head.
315
+ if self.num_heads == 1:
316
+ self.use_fpn = False
317
+ self.use_sam = False
318
+
319
+ if self.use_fpn:
320
+ fpn_pyramid_channels = self.stage_channels[-self.num_heads:]
321
+ self.fpn = FPN(pyramid_channels=fpn_pyramid_channels,
322
+ out_channels=self.fpn_channels)
323
+ if self.use_sam:
324
+ if self.use_fpn:
325
+ sam_pyramid_channels = [self.fpn_channels] * self.num_heads
326
+ else:
327
+ sam_pyramid_channels = self.stage_channels[-self.num_heads:]
328
+ self.sam = SAM(pyramid_channels=sam_pyramid_channels,
329
+ out_channels=self.sam_channels)
330
+
331
+ self.heads = nn.ModuleList()
332
+ for head_idx in range(self.num_heads):
333
+ # Parse in_channels.
334
+ if self.use_sam:
335
+ in_channels = self.sam_channels
336
+ elif self.use_fpn:
337
+ in_channels = self.fpn_channels
338
+ else:
339
+ in_channels = self.stage_channels[head_idx - self.num_heads]
340
+ in_channels = in_channels * final_res * final_res
341
+
342
+ # Parse out_channels.
343
+ start_latent_idx = sum(self.num_latents_per_head[:head_idx])
344
+ end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
345
+ out_channels = sum(self.latent_dim[start_latent_idx:end_latent_idx])
346
+
347
+ self.heads.append(CodeHead(in_channels=in_channels,
348
+ out_channels=out_channels,
349
+ norm_layer=self.norm_layer))
350
+
351
+ def _make_stage(self,
352
+ block_fn,
353
+ inplanes,
354
+ planes,
355
+ num_blocks,
356
+ stride,
357
+ dilate):
358
+ norm_layer = self.norm_layer
359
+ downsample = None
360
+ previous_dilation = self.dilation
361
+ if dilate:
362
+ self.dilation *= stride
363
+ stride = 1
364
+ if stride != 1 or inplanes != planes * block_fn.expansion:
365
+ downsample = nn.Sequential(
366
+ nn.Conv2d(in_channels=inplanes,
367
+ out_channels=planes * block_fn.expansion,
368
+ kernel_size=1,
369
+ stride=stride,
370
+ padding=0,
371
+ dilation=1,
372
+ groups=1,
373
+ bias=False),
374
+ norm_layer(planes * block_fn.expansion),
375
+ )
376
+
377
+ blocks = []
378
+ blocks.append(block_fn(inplanes=inplanes,
379
+ planes=planes,
380
+ base_width=self.base_width,
381
+ stride=stride,
382
+ groups=self.groups,
383
+ dilation=previous_dilation,
384
+ norm_layer=norm_layer,
385
+ downsample=downsample))
386
+ for _ in range(1, num_blocks):
387
+ blocks.append(block_fn(inplanes=planes * block_fn.expansion,
388
+ planes=planes,
389
+ base_width=self.base_width,
390
+ stride=1,
391
+ groups=self.groups,
392
+ dilation=self.dilation,
393
+ norm_layer=norm_layer,
394
+ downsample=None))
395
+
396
+ return nn.Sequential(*blocks)
397
+
398
+ def forward(self, x):
399
+ x = self.conv1(x)
400
+ x = self.bn1(x)
401
+ x = self.relu(x)
402
+ x = self.maxpool(x)
403
+
404
+ features = [x]
405
+ for i in range(self.num_stages):
406
+ x = self.stages[i](x)
407
+ features.append(x)
408
+ features = features[-self.num_heads:]
409
+
410
+ if self.use_fpn:
411
+ features = self.fpn(features)
412
+ if self.use_sam:
413
+ features = self.sam(features)
414
+ else:
415
+ final_size = features[-1].shape[2:]
416
+ for i in range(self.num_heads - 1):
417
+ features[i] = F.adaptive_avg_pool2d(features[i], final_size)
418
+
419
+ outputs = []
420
+ for head_idx in range(self.num_heads):
421
+ codes = self.heads[head_idx](features[head_idx])
422
+ start_latent_idx = sum(self.num_latents_per_head[:head_idx])
423
+ end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
424
+ split_size = self.latent_dim[start_latent_idx:end_latent_idx]
425
+ outputs.extend(torch.split(codes, split_size, dim=1))
426
+ max_dim = max(self.latent_dim)
427
+ for i, dim in enumerate(self.latent_dim):
428
+ if dim < max_dim:
429
+ outputs[i] = F.pad(outputs[i], (0, max_dim - dim))
430
+ outputs[i] = outputs[i].unsqueeze(1)
431
+
432
+ return torch.cat(outputs, dim=1)
433
+
434
+
435
+ class FPN(nn.Module):
436
+ """Implementation of Feature Pyramid Network (FPN).
437
+
438
+ The input of this module is a pyramid of features with reducing resolutions.
439
+ Then, this module fuses these multi-level features from `top_level` to
440
+ `bottom_level`. In particular, starting from the `top_level`, each feature
441
+ is convoluted, upsampled, and fused into its previous feature (which is also
442
+ convoluted).
443
+
444
+ Args:
445
+ pyramid_channels: A list of integers, each of which indicates the number
446
+ of channels of the feature from a particular level.
447
+ out_channels: Number of channels for each output.
448
+
449
+ Returns:
450
+ A list of feature maps, each of which has `out_channels` channels.
451
+ """
452
+
453
+ def __init__(self, pyramid_channels, out_channels):
454
+ super().__init__()
455
+ assert isinstance(pyramid_channels, (list, tuple))
456
+ self.num_levels = len(pyramid_channels)
457
+
458
+ self.lateral_layers = nn.ModuleList()
459
+ self.feature_layers = nn.ModuleList()
460
+ for i in range(self.num_levels):
461
+ in_channels = pyramid_channels[i]
462
+ self.lateral_layers.append(nn.Conv2d(in_channels=in_channels,
463
+ out_channels=out_channels,
464
+ kernel_size=3,
465
+ padding=1,
466
+ bias=True))
467
+ self.feature_layers.append(nn.Conv2d(in_channels=out_channels,
468
+ out_channels=out_channels,
469
+ kernel_size=3,
470
+ padding=1,
471
+ bias=True))
472
+
473
+ def forward(self, inputs):
474
+ if len(inputs) != self.num_levels:
475
+ raise ValueError('Number of inputs and `num_levels` mismatch!')
476
+
477
+ # Project all related features to `out_channels`.
478
+ laterals = []
479
+ for i in range(self.num_levels):
480
+ laterals.append(self.lateral_layers[i](inputs[i]))
481
+
482
+ # Fusion, starting from `top_level`.
483
+ for i in range(self.num_levels - 1, 0, -1):
484
+ scale_factor = laterals[i - 1].shape[2] // laterals[i].shape[2]
485
+ laterals[i - 1] = (laterals[i - 1] +
486
+ F.interpolate(laterals[i],
487
+ mode='nearest',
488
+ scale_factor=scale_factor))
489
+
490
+ # Get outputs.
491
+ outputs = []
492
+ for i, lateral in enumerate(laterals):
493
+ outputs.append(self.feature_layers[i](lateral))
494
+
495
+ return outputs
496
+
497
+
498
+ class SAM(nn.Module):
499
+ """Implementation of Spatial Alignment Module (SAM).
500
+
501
+ The input of this module is a pyramid of features with reducing resolutions.
502
+ Then this module downsamples all levels of feature to the minimum resolution
503
+ and fuses it with the smallest feature map.
504
+
505
+ Args:
506
+ pyramid_channels: A list of integers, each of which indicates the number
507
+ of channels of the feature from a particular level.
508
+ out_channels: Number of channels for each output.
509
+
510
+ Returns:
511
+ A list of feature maps, each of which has `out_channels` channels.
512
+ """
513
+
514
+ def __init__(self, pyramid_channels, out_channels):
515
+ super().__init__()
516
+ assert isinstance(pyramid_channels, (list, tuple))
517
+ self.num_levels = len(pyramid_channels)
518
+
519
+ self.fusion_layers = nn.ModuleList()
520
+ for i in range(self.num_levels):
521
+ in_channels = pyramid_channels[i]
522
+ self.fusion_layers.append(nn.Conv2d(in_channels=in_channels,
523
+ out_channels=out_channels,
524
+ kernel_size=3,
525
+ padding=1,
526
+ bias=True))
527
+
528
+ def forward(self, inputs):
529
+ if len(inputs) != self.num_levels:
530
+ raise ValueError('Number of inputs and `num_levels` mismatch!')
531
+
532
+ output_res = inputs[-1].shape[2:]
533
+ for i in range(self.num_levels - 1, -1, -1):
534
+ if i != self.num_levels - 1:
535
+ inputs[i] = F.adaptive_avg_pool2d(inputs[i], output_res)
536
+ inputs[i] = self.fusion_layers[i](inputs[i])
537
+ if i != self.num_levels - 1:
538
+ inputs[i] = inputs[i] + inputs[-1]
539
+
540
+ return inputs
541
+
542
+
543
+ class CodeHead(nn.Module):
544
+ """Implementation of the task-head to produce inverted codes."""
545
+
546
+ def __init__(self, in_channels, out_channels, norm_layer):
547
+ super().__init__()
548
+ self.fc = nn.Linear(in_channels, out_channels, bias=True)
549
+ if norm_layer is None:
550
+ self.norm = nn.Identity()
551
+ else:
552
+ self.norm = norm_layer(out_channels)
553
+
554
+ def forward(self, x):
555
+ if x.ndim > 2:
556
+ x = x.flatten(start_dim=1)
557
+ latent = self.fc(x)
558
+ latent = latent.unsqueeze(2).unsqueeze(3)
559
+ latent = self.norm(latent)
560
+
561
+ return latent.flatten(start_dim=1)
562
+
563
+ # pylint: enable=missing-function-docstring
models/inception_model.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the Inception V3 model, which is used for inference ONLY.
3
+
4
+ This file is mostly borrowed from `torchvision/models/inception.py`.
5
+
6
+ Inception model is widely used to compute FID or IS metric for evaluating
7
+ generative models. However, the pre-trained models from torchvision is slightly
8
+ different from the TensorFlow version
9
+
10
+ http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
11
+
12
+ which is used by the official FID implementation
13
+
14
+ https://github.com/bioinf-jku/TTUR
15
+
16
+ In particular:
17
+
18
+ (1) The number of classes in TensorFlow model is 1008 instead of 1000.
19
+ (2) The avg_pool() layers in TensorFlow model does not include the padded zero.
20
+ (3) The last Inception E Block in TensorFlow model use max_pool() instead of
21
+ avg_pool().
22
+
23
+ Hence, to align the evaluation results with those from TensorFlow
24
+ implementation, we modified the inception model to support both versions. Please
25
+ use `align_tf` argument to control the version.
26
+ """
27
+
28
+ import warnings
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+ import torch.distributed as dist
34
+
35
+ from utils.misc import download_url
36
+
37
+ __all__ = ['InceptionModel']
38
+
39
+ # pylint: disable=line-too-long
40
+
41
+ _MODEL_URL_SHA256 = {
42
+ # This model is provided by `torchvision`, which is ported from TensorFlow.
43
+ 'torchvision_official': (
44
+ 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
45
+ '1a9a5a14f40645a370184bd54f4e8e631351e71399112b43ad0294a79da290c8' # hash sha256
46
+ ),
47
+
48
+ # This model is provided by https://github.com/mseitzer/pytorch-fid
49
+ 'tf_inception_v3': (
50
+ 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth',
51
+ '6726825d0af5f729cebd5821db510b11b1cfad8faad88a03f1befd49fb9129b2' # hash sha256
52
+ )
53
+ }
54
+
55
+
56
+ class InceptionModel(object):
57
+ """Defines the Inception (V3) model.
58
+
59
+ This is a static class, which is used to avoid this model to be built
60
+ repeatedly. Consequently, this model is particularly used for inference,
61
+ like computing FID. If training is required, please use the model from
62
+ `torchvision.models` or implement by yourself.
63
+
64
+ NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
65
+ order and pixel range [-1, 1], and will also resize the images to shape
66
+ [299, 299] automatically. If your input is normalized by subtracting
67
+ (0.485, 0.456, 0.406) and dividing (0.229, 0.224, 0.225), please use
68
+ `transform_input` in the `forward()` function to un-normalize it.
69
+ """
70
+ models = dict()
71
+
72
+ @staticmethod
73
+ def build_model(align_tf=True):
74
+ """Builds the model and load pre-trained weights.
75
+
76
+ If `align_tf` is set as True, the model will predict 1008 classes, and
77
+ the pre-trained weight from `https://github.com/mseitzer/pytorch-fid`
78
+ will be loaded. Otherwise, the model will predict 1000 classes, and will
79
+ load the model from `torchvision`.
80
+
81
+ The built model supports following arguments when forwarding:
82
+
83
+ - transform_input: Whether to transform the input back to pixel range
84
+ (-1, 1). Please disable this argument if your input is already with
85
+ pixel range (-1, 1). (default: False)
86
+ - output_logits: Whether to output the categorical logits instead of
87
+ features. (default: False)
88
+ - remove_logits_bias: Whether to remove the bias when computing the
89
+ logits. The official implementation removes the bias by default.
90
+ Please refer to
91
+ `https://github.com/openai/improved-gan/blob/master/inception_score/model.py`.
92
+ (default: False)
93
+ - output_predictions: Whether to output the final predictions, i.e.,
94
+ `softmax(logits)`. (default: False)
95
+ """
96
+ if align_tf:
97
+ num_classes = 1008
98
+ model_source = 'tf_inception_v3'
99
+ else:
100
+ num_classes = 1000
101
+ model_source = 'torchvision_official'
102
+
103
+ fingerprint = model_source
104
+
105
+ if fingerprint not in InceptionModel.models:
106
+ # Build model.
107
+ model = Inception3(num_classes=num_classes,
108
+ aux_logits=False,
109
+ init_weights=False,
110
+ align_tf=align_tf)
111
+
112
+ # Download pre-trained weights.
113
+ if dist.is_initialized() and dist.get_rank() != 0:
114
+ dist.barrier() # Download by chief.
115
+
116
+ url, sha256 = _MODEL_URL_SHA256[model_source]
117
+ filename = f'inception_model_{model_source}_{sha256}.pth'
118
+ model_path, hash_check = download_url(url,
119
+ filename=filename,
120
+ sha256=sha256)
121
+ state_dict = torch.load(model_path, map_location='cpu')
122
+ if hash_check is False:
123
+ warnings.warn(f'Hash check failed! The remote file from URL '
124
+ f'`{url}` may be changed, or the downloading is '
125
+ f'interrupted. The loaded inception model may '
126
+ f'have unexpected behavior.')
127
+
128
+ if dist.is_initialized() and dist.get_rank() == 0:
129
+ dist.barrier() # Wait for other replicas.
130
+
131
+ # Load weights.
132
+ model.load_state_dict(state_dict, strict=False)
133
+ del state_dict
134
+
135
+ # For inference only.
136
+ model.eval().requires_grad_(False).cuda()
137
+ InceptionModel.models[fingerprint] = model
138
+
139
+ return InceptionModel.models[fingerprint]
140
+
141
+ # pylint: disable=missing-function-docstring
142
+ # pylint: disable=missing-class-docstring
143
+ # pylint: disable=super-with-arguments
144
+ # pylint: disable=consider-merging-isinstance
145
+ # pylint: disable=import-outside-toplevel
146
+ # pylint: disable=no-else-return
147
+
148
+ class Inception3(nn.Module):
149
+
150
+ def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None,
151
+ init_weights=True, align_tf=True):
152
+ super(Inception3, self).__init__()
153
+ if inception_blocks is None:
154
+ inception_blocks = [
155
+ BasicConv2d, InceptionA, InceptionB, InceptionC,
156
+ InceptionD, InceptionE, InceptionAux
157
+ ]
158
+ assert len(inception_blocks) == 7
159
+ conv_block = inception_blocks[0]
160
+ inception_a = inception_blocks[1]
161
+ inception_b = inception_blocks[2]
162
+ inception_c = inception_blocks[3]
163
+ inception_d = inception_blocks[4]
164
+ inception_e = inception_blocks[5]
165
+ inception_aux = inception_blocks[6]
166
+
167
+ self.aux_logits = aux_logits
168
+ self.align_tf = align_tf
169
+ self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
170
+ self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
171
+ self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
172
+ self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
173
+ self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
174
+ self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf)
175
+ self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf)
176
+ self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf)
177
+ self.Mixed_6a = inception_b(288)
178
+ self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf)
179
+ self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
180
+ self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
181
+ self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf)
182
+ if aux_logits:
183
+ self.AuxLogits = inception_aux(768, num_classes)
184
+ self.Mixed_7a = inception_d(768)
185
+ self.Mixed_7b = inception_e(1280, align_tf=self.align_tf)
186
+ self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf)
187
+ self.fc = nn.Linear(2048, num_classes)
188
+ if init_weights:
189
+ for m in self.modules():
190
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
191
+ import scipy.stats as stats
192
+ stddev = m.stddev if hasattr(m, 'stddev') else 0.1
193
+ X = stats.truncnorm(-2, 2, scale=stddev)
194
+ values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
195
+ values = values.view(m.weight.size())
196
+ with torch.no_grad():
197
+ m.weight.copy_(values)
198
+ elif isinstance(m, nn.BatchNorm2d):
199
+ nn.init.constant_(m.weight, 1)
200
+ nn.init.constant_(m.bias, 0)
201
+
202
+ @staticmethod
203
+ def _transform_input(x, transform_input=False):
204
+ if transform_input:
205
+ x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
206
+ x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
207
+ x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
208
+ x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
209
+ return x
210
+
211
+ def _forward(self,
212
+ x,
213
+ output_logits=False,
214
+ remove_logits_bias=False,
215
+ output_predictions=False):
216
+ # Upsample if necessary.
217
+ if x.shape[2] != 299 or x.shape[3] != 299:
218
+ if self.align_tf:
219
+ theta = torch.eye(2, 3).to(x)
220
+ theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 299
221
+ theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 299
222
+ theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
223
+ grid = F.affine_grid(theta,
224
+ size=(x.shape[0], x.shape[1], 299, 299),
225
+ align_corners=False)
226
+ x = F.grid_sample(x, grid,
227
+ mode='bilinear',
228
+ padding_mode='border',
229
+ align_corners=False)
230
+ else:
231
+ x = F.interpolate(
232
+ x, size=(299, 299), mode='bilinear', align_corners=False)
233
+ if x.shape[1] == 1:
234
+ x = x.repeat((1, 3, 1, 1))
235
+
236
+ if self.align_tf:
237
+ x = (x * 127.5 + 127.5 - 128) / 128
238
+
239
+ # N x 3 x 299 x 299
240
+ x = self.Conv2d_1a_3x3(x)
241
+ # N x 32 x 149 x 149
242
+ x = self.Conv2d_2a_3x3(x)
243
+ # N x 32 x 147 x 147
244
+ x = self.Conv2d_2b_3x3(x)
245
+ # N x 64 x 147 x 147
246
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
247
+ # N x 64 x 73 x 73
248
+ x = self.Conv2d_3b_1x1(x)
249
+ # N x 80 x 73 x 73
250
+ x = self.Conv2d_4a_3x3(x)
251
+ # N x 192 x 71 x 71
252
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
253
+ # N x 192 x 35 x 35
254
+ x = self.Mixed_5b(x)
255
+ # N x 256 x 35 x 35
256
+ x = self.Mixed_5c(x)
257
+ # N x 288 x 35 x 35
258
+ x = self.Mixed_5d(x)
259
+ # N x 288 x 35 x 35
260
+ x = self.Mixed_6a(x)
261
+ # N x 768 x 17 x 17
262
+ x = self.Mixed_6b(x)
263
+ # N x 768 x 17 x 17
264
+ x = self.Mixed_6c(x)
265
+ # N x 768 x 17 x 17
266
+ x = self.Mixed_6d(x)
267
+ # N x 768 x 17 x 17
268
+ x = self.Mixed_6e(x)
269
+ # N x 768 x 17 x 17
270
+ if self.training and self.aux_logits:
271
+ aux = self.AuxLogits(x)
272
+ else:
273
+ aux = None
274
+ # N x 768 x 17 x 17
275
+ x = self.Mixed_7a(x)
276
+ # N x 1280 x 8 x 8
277
+ x = self.Mixed_7b(x)
278
+ # N x 2048 x 8 x 8
279
+ x = self.Mixed_7c(x)
280
+ # N x 2048 x 8 x 8
281
+ # Adaptive average pooling
282
+ x = F.adaptive_avg_pool2d(x, (1, 1))
283
+ # N x 2048 x 1 x 1
284
+ x = F.dropout(x, training=self.training)
285
+ # N x 2048 x 1 x 1
286
+ x = torch.flatten(x, 1)
287
+ # N x 2048
288
+ if output_logits or output_predictions:
289
+ x = self.fc(x)
290
+ # N x 1000 (num_classes)
291
+ if remove_logits_bias:
292
+ x = x - self.fc.bias.view(1, -1)
293
+ if output_predictions:
294
+ x = F.softmax(x, dim=1)
295
+ return x, aux
296
+
297
+ def forward(self,
298
+ x,
299
+ transform_input=False,
300
+ output_logits=False,
301
+ remove_logits_bias=False,
302
+ output_predictions=False):
303
+ x = self._transform_input(x, transform_input)
304
+ x, aux = self._forward(
305
+ x, output_logits, remove_logits_bias, output_predictions)
306
+ if self.training and self.aux_logits:
307
+ return x, aux
308
+ else:
309
+ return x
310
+
311
+
312
+ class InceptionA(nn.Module):
313
+
314
+ def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False):
315
+ super(InceptionA, self).__init__()
316
+ if conv_block is None:
317
+ conv_block = BasicConv2d
318
+ self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
319
+
320
+ self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
321
+ self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
322
+
323
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
324
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
325
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
326
+
327
+ self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
328
+ self.pool_include_padding = not align_tf
329
+
330
+ def _forward(self, x):
331
+ branch1x1 = self.branch1x1(x)
332
+
333
+ branch5x5 = self.branch5x5_1(x)
334
+ branch5x5 = self.branch5x5_2(branch5x5)
335
+
336
+ branch3x3dbl = self.branch3x3dbl_1(x)
337
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
338
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
339
+
340
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
341
+ count_include_pad=self.pool_include_padding)
342
+ branch_pool = self.branch_pool(branch_pool)
343
+
344
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
345
+ return outputs
346
+
347
+ def forward(self, x):
348
+ outputs = self._forward(x)
349
+ return torch.cat(outputs, 1)
350
+
351
+
352
+ class InceptionB(nn.Module):
353
+
354
+ def __init__(self, in_channels, conv_block=None):
355
+ super(InceptionB, self).__init__()
356
+ if conv_block is None:
357
+ conv_block = BasicConv2d
358
+ self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
359
+
360
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
361
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
362
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
363
+
364
+ def _forward(self, x):
365
+ branch3x3 = self.branch3x3(x)
366
+
367
+ branch3x3dbl = self.branch3x3dbl_1(x)
368
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
369
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
370
+
371
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
372
+
373
+ outputs = [branch3x3, branch3x3dbl, branch_pool]
374
+ return outputs
375
+
376
+ def forward(self, x):
377
+ outputs = self._forward(x)
378
+ return torch.cat(outputs, 1)
379
+
380
+
381
+ class InceptionC(nn.Module):
382
+
383
+ def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False):
384
+ super(InceptionC, self).__init__()
385
+ if conv_block is None:
386
+ conv_block = BasicConv2d
387
+ self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
388
+
389
+ c7 = channels_7x7
390
+ self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
391
+ self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
392
+ self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
393
+
394
+ self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
395
+ self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
396
+ self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
397
+ self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
398
+ self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
399
+
400
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
401
+ self.pool_include_padding = not align_tf
402
+
403
+ def _forward(self, x):
404
+ branch1x1 = self.branch1x1(x)
405
+
406
+ branch7x7 = self.branch7x7_1(x)
407
+ branch7x7 = self.branch7x7_2(branch7x7)
408
+ branch7x7 = self.branch7x7_3(branch7x7)
409
+
410
+ branch7x7dbl = self.branch7x7dbl_1(x)
411
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
412
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
413
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
414
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
415
+
416
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
417
+ count_include_pad=self.pool_include_padding)
418
+ branch_pool = self.branch_pool(branch_pool)
419
+
420
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
421
+ return outputs
422
+
423
+ def forward(self, x):
424
+ outputs = self._forward(x)
425
+ return torch.cat(outputs, 1)
426
+
427
+
428
+ class InceptionD(nn.Module):
429
+
430
+ def __init__(self, in_channels, conv_block=None):
431
+ super(InceptionD, self).__init__()
432
+ if conv_block is None:
433
+ conv_block = BasicConv2d
434
+ self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
435
+ self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
436
+
437
+ self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
438
+ self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
439
+ self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
440
+ self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
441
+
442
+ def _forward(self, x):
443
+ branch3x3 = self.branch3x3_1(x)
444
+ branch3x3 = self.branch3x3_2(branch3x3)
445
+
446
+ branch7x7x3 = self.branch7x7x3_1(x)
447
+ branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
448
+ branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
449
+ branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
450
+
451
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
452
+ outputs = [branch3x3, branch7x7x3, branch_pool]
453
+ return outputs
454
+
455
+ def forward(self, x):
456
+ outputs = self._forward(x)
457
+ return torch.cat(outputs, 1)
458
+
459
+
460
+ class InceptionE(nn.Module):
461
+
462
+ def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False):
463
+ super(InceptionE, self).__init__()
464
+ if conv_block is None:
465
+ conv_block = BasicConv2d
466
+ self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
467
+
468
+ self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
469
+ self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
470
+ self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
471
+
472
+ self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
473
+ self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
474
+ self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
475
+ self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
476
+
477
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
478
+ self.pool_include_padding = not align_tf
479
+ self.use_max_pool = use_max_pool
480
+
481
+ def _forward(self, x):
482
+ branch1x1 = self.branch1x1(x)
483
+
484
+ branch3x3 = self.branch3x3_1(x)
485
+ branch3x3 = [
486
+ self.branch3x3_2a(branch3x3),
487
+ self.branch3x3_2b(branch3x3),
488
+ ]
489
+ branch3x3 = torch.cat(branch3x3, 1)
490
+
491
+ branch3x3dbl = self.branch3x3dbl_1(x)
492
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
493
+ branch3x3dbl = [
494
+ self.branch3x3dbl_3a(branch3x3dbl),
495
+ self.branch3x3dbl_3b(branch3x3dbl),
496
+ ]
497
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
498
+
499
+ if self.use_max_pool:
500
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
501
+ else:
502
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
503
+ count_include_pad=self.pool_include_padding)
504
+ branch_pool = self.branch_pool(branch_pool)
505
+
506
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
507
+ return outputs
508
+
509
+ def forward(self, x):
510
+ outputs = self._forward(x)
511
+ return torch.cat(outputs, 1)
512
+
513
+
514
+ class InceptionAux(nn.Module):
515
+
516
+ def __init__(self, in_channels, num_classes, conv_block=None):
517
+ super(InceptionAux, self).__init__()
518
+ if conv_block is None:
519
+ conv_block = BasicConv2d
520
+ self.conv0 = conv_block(in_channels, 128, kernel_size=1)
521
+ self.conv1 = conv_block(128, 768, kernel_size=5)
522
+ self.conv1.stddev = 0.01
523
+ self.fc = nn.Linear(768, num_classes)
524
+ self.fc.stddev = 0.001
525
+
526
+ def forward(self, x):
527
+ # N x 768 x 17 x 17
528
+ x = F.avg_pool2d(x, kernel_size=5, stride=3)
529
+ # N x 768 x 5 x 5
530
+ x = self.conv0(x)
531
+ # N x 128 x 5 x 5
532
+ x = self.conv1(x)
533
+ # N x 768 x 1 x 1
534
+ # Adaptive average pooling
535
+ x = F.adaptive_avg_pool2d(x, (1, 1))
536
+ # N x 768 x 1 x 1
537
+ x = torch.flatten(x, 1)
538
+ # N x 768
539
+ x = self.fc(x)
540
+ # N x 1000
541
+ return x
542
+
543
+
544
+ class BasicConv2d(nn.Module):
545
+
546
+ def __init__(self, in_channels, out_channels, **kwargs):
547
+ super(BasicConv2d, self).__init__()
548
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
549
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
550
+
551
+ def forward(self, x):
552
+ x = self.conv(x)
553
+ x = self.bn(x)
554
+ return F.relu(x, inplace=True)
555
+
556
+ # pylint: enable=line-too-long
557
+ # pylint: enable=missing-function-docstring
558
+ # pylint: enable=missing-class-docstring
559
+ # pylint: enable=super-with-arguments
560
+ # pylint: enable=consider-merging-isinstance
561
+ # pylint: enable=import-outside-toplevel
562
+ # pylint: enable=no-else-return
models/perceptual_model.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the VGG16 model, which is used for inference ONLY.
3
+
4
+ VGG16 is commonly used for perceptual feature extraction. The model implemented
5
+ in this file can be used for evaluation (like computing LPIPS, perceptual path
6
+ length, etc.), OR be used in training for loss computation (like perceptual
7
+ loss, etc.).
8
+
9
+ The pre-trained model is officially shared by
10
+
11
+ https://www.robots.ox.ac.uk/~vgg/research/very_deep/
12
+
13
+ and ported by
14
+
15
+ https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt
16
+
17
+ Compared to the official VGG16 model, this ported model also support evaluating
18
+ LPIPS, which is introduced in
19
+
20
+ https://github.com/richzhang/PerceptualSimilarity
21
+ """
22
+
23
+ import warnings
24
+ import numpy as np
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ import torch.distributed as dist
30
+
31
+ from utils.misc import download_url
32
+
33
+ __all__ = ['PerceptualModel']
34
+
35
+ # pylint: disable=line-too-long
36
+ _MODEL_URL_SHA256 = {
37
+ # This model is provided by `torchvision`, which is ported from TensorFlow.
38
+ 'torchvision_official': (
39
+ 'https://download.pytorch.org/models/vgg16-397923af.pth',
40
+ '397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0' # hash sha256
41
+ ),
42
+
43
+ # This model is provided by https://github.com/NVlabs/stylegan2-ada-pytorch
44
+ 'vgg_perceptual_lpips': (
45
+ 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt',
46
+ 'b437eb095feaeb0b83eb3fa11200ebca4548ee39a07fb944a417ddc516cc07c3' # hash sha256
47
+ )
48
+ }
49
+ # pylint: enable=line-too-long
50
+
51
+
52
+ class PerceptualModel(object):
53
+ """Defines the perceptual model, which is based on VGG16 structure.
54
+
55
+ This is a static class, which is used to avoid this model to be built
56
+ repeatedly. Consequently, this model is particularly used for inference,
57
+ like computing LPIPS, or for loss computation, like perceptual loss. If
58
+ training is required, please use the model from `torchvision.models` or
59
+ implement by yourself.
60
+
61
+ NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
62
+ order and pixel range [-1, 1], and will NOT resize the input automatically
63
+ if only perceptual feature is needed.
64
+ """
65
+ models = dict()
66
+
67
+ @staticmethod
68
+ def build_model(use_torchvision=False, no_top=True, enable_lpips=True):
69
+ """Builds the model and load pre-trained weights.
70
+
71
+ 1. If `use_torchvision` is set as True, the model released by
72
+ `torchvision` will be loaded, otherwise, the model released by
73
+ https://www.robots.ox.ac.uk/~vgg/research/very_deep/ will be used.
74
+ (default: False)
75
+
76
+ 2. To save computing resources, these is an option to only load the
77
+ backbone (i.e., without the last three fully-connected layers). This
78
+ is commonly used for perceptual loss or LPIPS loss computation.
79
+ Please use argument `no_top` to control this. (default: True)
80
+
81
+ 3. For LPIPS loss computation, some additional weights (which is used
82
+ for balancing the features from different resolutions) are employed
83
+ on top of the original VGG16 backbone. Details can be found at
84
+ https://github.com/richzhang/PerceptualSimilarity. Please use
85
+ `enable_lpips` to enable this feature. (default: True)
86
+
87
+ The built model supports following arguments when forwarding:
88
+
89
+ - resize_input: Whether to resize the input image to size [224, 224]
90
+ before forwarding. For feature-based computation (i.e., only
91
+ convolutional layers are used), image resizing is not essential.
92
+ (default: False)
93
+ - return_tensor: This field resolves the model behavior. Following
94
+ options are supported:
95
+ `feature1`: Before the first max pooling layer.
96
+ `pool1`: After the first max pooling layer.
97
+ `feature2`: Before the second max pooling layer.
98
+ `pool2`: After the second max pooling layer.
99
+ `feature3`: Before the third max pooling layer.
100
+ `pool3`: After the third max pooling layer.
101
+ `feature4`: Before the fourth max pooling layer.
102
+ `pool4`: After the fourth max pooling layer.
103
+ `feature5`: Before the fifth max pooling layer.
104
+ `pool5`: After the fifth max pooling layer.
105
+ `flatten`: The flattened feature, after `adaptive_avgpool`.
106
+ `feature`: The 4096d feature for logits computation. (default)
107
+ `logits`: The 1000d categorical logits.
108
+ `prediction`: The 1000d predicted probability.
109
+ `lpips`: The LPIPS score between two input images.
110
+ """
111
+ if use_torchvision:
112
+ model_source = 'torchvision_official'
113
+ align_tf_resize = False
114
+ is_torch_script = False
115
+ else:
116
+ model_source = 'vgg_perceptual_lpips'
117
+ align_tf_resize = True
118
+ is_torch_script = True
119
+
120
+ if enable_lpips and model_source != 'vgg_perceptual_lpips':
121
+ warnings.warn('The pre-trained model officially released by '
122
+ '`torchvision` does not support LPIPS computation! '
123
+ 'Equal weights will be used for each resolution.')
124
+
125
+ fingerprint = (model_source, no_top, enable_lpips)
126
+
127
+ if fingerprint not in PerceptualModel.models:
128
+ # Build model.
129
+ model = VGG16(align_tf_resize=align_tf_resize,
130
+ no_top=no_top,
131
+ enable_lpips=enable_lpips)
132
+
133
+ # Download pre-trained weights.
134
+ if dist.is_initialized() and dist.get_rank() != 0:
135
+ dist.barrier() # Download by chief.
136
+
137
+ url, sha256 = _MODEL_URL_SHA256[model_source]
138
+ filename = f'perceptual_model_{model_source}_{sha256}.pth'
139
+ model_path, hash_check = download_url(url,
140
+ filename=filename,
141
+ sha256=sha256)
142
+ if is_torch_script:
143
+ src_state_dict = torch.jit.load(model_path, map_location='cpu')
144
+ else:
145
+ src_state_dict = torch.load(model_path, map_location='cpu')
146
+ if hash_check is False:
147
+ warnings.warn(f'Hash check failed! The remote file from URL '
148
+ f'`{url}` may be changed, or the downloading is '
149
+ f'interrupted. The loaded perceptual model may '
150
+ f'have unexpected behavior.')
151
+
152
+ if dist.is_initialized() and dist.get_rank() == 0:
153
+ dist.barrier() # Wait for other replicas.
154
+
155
+ # Load weights.
156
+ dst_state_dict = _convert_weights(src_state_dict, model_source)
157
+ model.load_state_dict(dst_state_dict, strict=False)
158
+ del src_state_dict, dst_state_dict
159
+
160
+ # For inference only.
161
+ model.eval().requires_grad_(False).cuda()
162
+ PerceptualModel.models[fingerprint] = model
163
+
164
+ return PerceptualModel.models[fingerprint]
165
+
166
+
167
+ def _convert_weights(src_state_dict, model_source):
168
+ if model_source not in _MODEL_URL_SHA256:
169
+ raise ValueError(f'Invalid model source `{model_source}`!\n'
170
+ f'Sources allowed: {list(_MODEL_URL_SHA256.keys())}.')
171
+ if model_source == 'torchvision_official':
172
+ dst_to_src_var_mapping = {
173
+ 'conv11.weight': 'features.0.weight',
174
+ 'conv11.bias': 'features.0.bias',
175
+ 'conv12.weight': 'features.2.weight',
176
+ 'conv12.bias': 'features.2.bias',
177
+ 'conv21.weight': 'features.5.weight',
178
+ 'conv21.bias': 'features.5.bias',
179
+ 'conv22.weight': 'features.7.weight',
180
+ 'conv22.bias': 'features.7.bias',
181
+ 'conv31.weight': 'features.10.weight',
182
+ 'conv31.bias': 'features.10.bias',
183
+ 'conv32.weight': 'features.12.weight',
184
+ 'conv32.bias': 'features.12.bias',
185
+ 'conv33.weight': 'features.14.weight',
186
+ 'conv33.bias': 'features.14.bias',
187
+ 'conv41.weight': 'features.17.weight',
188
+ 'conv41.bias': 'features.17.bias',
189
+ 'conv42.weight': 'features.19.weight',
190
+ 'conv42.bias': 'features.19.bias',
191
+ 'conv43.weight': 'features.21.weight',
192
+ 'conv43.bias': 'features.21.bias',
193
+ 'conv51.weight': 'features.24.weight',
194
+ 'conv51.bias': 'features.24.bias',
195
+ 'conv52.weight': 'features.26.weight',
196
+ 'conv52.bias': 'features.26.bias',
197
+ 'conv53.weight': 'features.28.weight',
198
+ 'conv53.bias': 'features.28.bias',
199
+ 'fc1.weight': 'classifier.0.weight',
200
+ 'fc1.bias': 'classifier.0.bias',
201
+ 'fc2.weight': 'classifier.3.weight',
202
+ 'fc2.bias': 'classifier.3.bias',
203
+ 'fc3.weight': 'classifier.6.weight',
204
+ 'fc3.bias': 'classifier.6.bias',
205
+ }
206
+ elif model_source == 'vgg_perceptual_lpips':
207
+ src_state_dict = src_state_dict.state_dict()
208
+ dst_to_src_var_mapping = {
209
+ 'conv11.weight': 'layers.conv1.weight',
210
+ 'conv11.bias': 'layers.conv1.bias',
211
+ 'conv12.weight': 'layers.conv2.weight',
212
+ 'conv12.bias': 'layers.conv2.bias',
213
+ 'conv21.weight': 'layers.conv3.weight',
214
+ 'conv21.bias': 'layers.conv3.bias',
215
+ 'conv22.weight': 'layers.conv4.weight',
216
+ 'conv22.bias': 'layers.conv4.bias',
217
+ 'conv31.weight': 'layers.conv5.weight',
218
+ 'conv31.bias': 'layers.conv5.bias',
219
+ 'conv32.weight': 'layers.conv6.weight',
220
+ 'conv32.bias': 'layers.conv6.bias',
221
+ 'conv33.weight': 'layers.conv7.weight',
222
+ 'conv33.bias': 'layers.conv7.bias',
223
+ 'conv41.weight': 'layers.conv8.weight',
224
+ 'conv41.bias': 'layers.conv8.bias',
225
+ 'conv42.weight': 'layers.conv9.weight',
226
+ 'conv42.bias': 'layers.conv9.bias',
227
+ 'conv43.weight': 'layers.conv10.weight',
228
+ 'conv43.bias': 'layers.conv10.bias',
229
+ 'conv51.weight': 'layers.conv11.weight',
230
+ 'conv51.bias': 'layers.conv11.bias',
231
+ 'conv52.weight': 'layers.conv12.weight',
232
+ 'conv52.bias': 'layers.conv12.bias',
233
+ 'conv53.weight': 'layers.conv13.weight',
234
+ 'conv53.bias': 'layers.conv13.bias',
235
+ 'fc1.weight': 'layers.fc1.weight',
236
+ 'fc1.bias': 'layers.fc1.bias',
237
+ 'fc2.weight': 'layers.fc2.weight',
238
+ 'fc2.bias': 'layers.fc2.bias',
239
+ 'fc3.weight': 'layers.fc3.weight',
240
+ 'fc3.bias': 'layers.fc3.bias',
241
+ 'lpips.0.weight': 'lpips0',
242
+ 'lpips.1.weight': 'lpips1',
243
+ 'lpips.2.weight': 'lpips2',
244
+ 'lpips.3.weight': 'lpips3',
245
+ 'lpips.4.weight': 'lpips4',
246
+ }
247
+ else:
248
+ raise NotImplementedError(f'Not implemented model source '
249
+ f'`{model_source}`!')
250
+
251
+ dst_state_dict = {}
252
+ for dst_name, src_name in dst_to_src_var_mapping.items():
253
+ if dst_name.startswith('lpips'):
254
+ dst_state_dict[dst_name] = src_state_dict[src_name].unsqueeze(0)
255
+ else:
256
+ dst_state_dict[dst_name] = src_state_dict[src_name].clone()
257
+ return dst_state_dict
258
+
259
+
260
+ _IMG_MEAN = (0.485, 0.456, 0.406)
261
+ _IMG_STD = (0.229, 0.224, 0.225)
262
+ _ALLOWED_RETURN = [
263
+ 'feature1', 'pool1', 'feature2', 'pool2', 'feature3', 'pool3', 'feature4',
264
+ 'pool4', 'feature5', 'pool5', 'flatten', 'feature', 'logits', 'prediction',
265
+ 'lpips'
266
+ ]
267
+
268
+ # pylint: disable=missing-function-docstring
269
+
270
+ class VGG16(nn.Module):
271
+ """Defines the VGG16 structure.
272
+
273
+ This model takes `RGB` images with data format `NCHW` as the raw inputs. The
274
+ pixel range are assumed to be [-1, 1].
275
+ """
276
+
277
+ def __init__(self, align_tf_resize=False, no_top=True, enable_lpips=True):
278
+ """Defines the network structure."""
279
+ super().__init__()
280
+
281
+ self.align_tf_resize = align_tf_resize
282
+ self.no_top = no_top
283
+ self.enable_lpips = enable_lpips
284
+
285
+ self.conv11 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
286
+ self.relu11 = nn.ReLU(inplace=True)
287
+ self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
288
+ self.relu12 = nn.ReLU(inplace=True)
289
+ # output `feature1`, with shape [N, 64, 224, 224]
290
+
291
+ self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
292
+ # output `pool1`, with shape [N, 64, 112, 112]
293
+
294
+ self.conv21 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
295
+ self.relu21 = nn.ReLU(inplace=True)
296
+ self.conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
297
+ self.relu22 = nn.ReLU(inplace=True)
298
+ # output `feature2`, with shape [N, 128, 112, 112]
299
+
300
+ self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
301
+ # output `pool2`, with shape [N, 128, 56, 56]
302
+
303
+ self.conv31 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
304
+ self.relu31 = nn.ReLU(inplace=True)
305
+ self.conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
306
+ self.relu32 = nn.ReLU(inplace=True)
307
+ self.conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
308
+ self.relu33 = nn.ReLU(inplace=True)
309
+ # output `feature3`, with shape [N, 256, 56, 56]
310
+
311
+ self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
312
+ # output `pool3`, with shape [N,256, 28, 28]
313
+
314
+ self.conv41 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
315
+ self.relu41 = nn.ReLU(inplace=True)
316
+ self.conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
317
+ self.relu42 = nn.ReLU(inplace=True)
318
+ self.conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
319
+ self.relu43 = nn.ReLU(inplace=True)
320
+ # output `feature4`, with shape [N, 512, 28, 28]
321
+
322
+ self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
323
+ # output `pool4`, with shape [N, 512, 14, 14]
324
+
325
+ self.conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
326
+ self.relu51 = nn.ReLU(inplace=True)
327
+ self.conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
328
+ self.relu52 = nn.ReLU(inplace=True)
329
+ self.conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
330
+ self.relu53 = nn.ReLU(inplace=True)
331
+ # output `feature5`, with shape [N, 512, 14, 14]
332
+
333
+ self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
334
+ # output `pool5`, with shape [N, 512, 7, 7]
335
+
336
+ if self.enable_lpips:
337
+ self.lpips = nn.ModuleList()
338
+ for idx, ch in enumerate([64, 128, 256, 512, 512]):
339
+ self.lpips.append(nn.Conv2d(ch, 1, kernel_size=1, bias=False))
340
+ self.lpips[idx].weight.data.copy_(torch.ones(1, ch, 1, 1))
341
+
342
+ if not self.no_top:
343
+ self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
344
+ self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
345
+ # output `flatten`, with shape [N, 25088]
346
+
347
+ self.fc1 = nn.Linear(512 * 7 * 7, 4096)
348
+ self.fc1_relu = nn.ReLU(inplace=True)
349
+ self.fc1_dropout = nn.Dropout(0.5, inplace=False)
350
+ self.fc2 = nn.Linear(4096, 4096)
351
+ self.fc2_relu = nn.ReLU(inplace=True)
352
+ self.fc2_dropout = nn.Dropout(0.5, inplace=False)
353
+ # output `feature`, with shape [N, 4096]
354
+
355
+ self.fc3 = nn.Linear(4096, 1000)
356
+ # output `logits`, with shape [N, 1000]
357
+
358
+ self.out = nn.Softmax(dim=1)
359
+ # output `softmax`, with shape [N, 1000]
360
+
361
+ img_mean = np.array(_IMG_MEAN).reshape((1, 3, 1, 1)).astype(np.float32)
362
+ img_std = np.array(_IMG_STD).reshape((1, 3, 1, 1)).astype(np.float32)
363
+ self.register_buffer('img_mean', torch.from_numpy(img_mean))
364
+ self.register_buffer('img_std', torch.from_numpy(img_std))
365
+
366
+ def forward(self,
367
+ x,
368
+ y=None,
369
+ *,
370
+ resize_input=False,
371
+ return_tensor='feature'):
372
+ return_tensor = return_tensor.lower()
373
+ if return_tensor not in _ALLOWED_RETURN:
374
+ raise ValueError(f'Invalid output tensor name `{return_tensor}` '
375
+ f'for perceptual model (VGG16)!\n'
376
+ f'Names allowed: {_ALLOWED_RETURN}.')
377
+
378
+ if return_tensor == 'lpips' and y is None:
379
+ raise ValueError('Two images are required for LPIPS computation, '
380
+ 'but only one is received!')
381
+
382
+ if return_tensor == 'lpips':
383
+ assert x.shape == y.shape
384
+ x = torch.cat([x, y], dim=0)
385
+ features = []
386
+
387
+ if resize_input:
388
+ if self.align_tf_resize:
389
+ theta = torch.eye(2, 3).to(x)
390
+ theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 224
391
+ theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 224
392
+ theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
393
+ grid = F.affine_grid(theta,
394
+ size=(x.shape[0], x.shape[1], 224, 224),
395
+ align_corners=False)
396
+ x = F.grid_sample(x, grid,
397
+ mode='bilinear',
398
+ padding_mode='border',
399
+ align_corners=False)
400
+ else:
401
+ x = F.interpolate(x,
402
+ size=(224, 224),
403
+ mode='bilinear',
404
+ align_corners=False)
405
+ if x.shape[1] == 1:
406
+ x = x.repeat((1, 3, 1, 1))
407
+
408
+ x = (x + 1) / 2
409
+ x = (x - self.img_mean) / self.img_std
410
+
411
+ x = self.conv11(x)
412
+ x = self.relu11(x)
413
+ x = self.conv12(x)
414
+ x = self.relu12(x)
415
+ if return_tensor == 'feature1':
416
+ return x
417
+ if return_tensor == 'lpips':
418
+ features.append(x)
419
+
420
+ x = self.pool1(x)
421
+ if return_tensor == 'pool1':
422
+ return x
423
+
424
+ x = self.conv21(x)
425
+ x = self.relu21(x)
426
+ x = self.conv22(x)
427
+ x = self.relu22(x)
428
+ if return_tensor == 'feature2':
429
+ return x
430
+ if return_tensor == 'lpips':
431
+ features.append(x)
432
+
433
+ x = self.pool2(x)
434
+ if return_tensor == 'pool2':
435
+ return x
436
+
437
+ x = self.conv31(x)
438
+ x = self.relu31(x)
439
+ x = self.conv32(x)
440
+ x = self.relu32(x)
441
+ x = self.conv33(x)
442
+ x = self.relu33(x)
443
+ if return_tensor == 'feature3':
444
+ return x
445
+ if return_tensor == 'lpips':
446
+ features.append(x)
447
+
448
+ x = self.pool3(x)
449
+ if return_tensor == 'pool3':
450
+ return x
451
+
452
+ x = self.conv41(x)
453
+ x = self.relu41(x)
454
+ x = self.conv42(x)
455
+ x = self.relu42(x)
456
+ x = self.conv43(x)
457
+ x = self.relu43(x)
458
+ if return_tensor == 'feature4':
459
+ return x
460
+ if return_tensor == 'lpips':
461
+ features.append(x)
462
+
463
+ x = self.pool4(x)
464
+ if return_tensor == 'pool4':
465
+ return x
466
+
467
+ x = self.conv51(x)
468
+ x = self.relu51(x)
469
+ x = self.conv52(x)
470
+ x = self.relu52(x)
471
+ x = self.conv53(x)
472
+ x = self.relu53(x)
473
+ if return_tensor == 'feature5':
474
+ return x
475
+ if return_tensor == 'lpips':
476
+ features.append(x)
477
+
478
+ x = self.pool5(x)
479
+ if return_tensor == 'pool5':
480
+ return x
481
+
482
+ if return_tensor == 'lpips':
483
+ score = 0
484
+ assert len(features) == 5
485
+ for idx in range(5):
486
+ feature = features[idx]
487
+ norm = feature.norm(dim=1, keepdim=True)
488
+ feature = feature / (norm + 1e-10)
489
+ feature_x, feature_y = feature.chunk(2, dim=0)
490
+ diff = (feature_x - feature_y).square()
491
+ score += self.lpips[idx](diff).mean(dim=(2, 3), keepdim=False)
492
+ return score.sum(dim=1, keepdim=False)
493
+
494
+ x = self.avgpool(x)
495
+ x = self.flatten(x)
496
+ if return_tensor == 'flatten':
497
+ return x
498
+
499
+ x = self.fc1(x)
500
+ x = self.fc1_relu(x)
501
+ x = self.fc1_dropout(x)
502
+ x = self.fc2(x)
503
+ x = self.fc2_relu(x)
504
+ x = self.fc2_dropout(x)
505
+ if return_tensor == 'feature':
506
+ return x
507
+
508
+ x = self.fc3(x)
509
+ if return_tensor == 'logits':
510
+ return x
511
+
512
+ x = self.out(x)
513
+ if return_tensor == 'prediction':
514
+ return x
515
+
516
+ raise NotImplementedError(f'Output tensor name `{return_tensor}` is '
517
+ f'not implemented!')
518
+
519
+ # pylint: enable=missing-function-docstring
models/pggan_discriminator.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the implementation of discriminator described in PGGAN.
3
+
4
+ Paper: https://arxiv.org/pdf/1710.10196.pdf
5
+
6
+ Official TensorFlow implementation:
7
+ https://github.com/tkarras/progressive_growing_of_gans
8
+ """
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ __all__ = ['PGGANDiscriminator']
17
+
18
+ # Resolutions allowed.
19
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
20
+
21
+ # Default gain factor for weight scaling.
22
+ _WSCALE_GAIN = np.sqrt(2.0)
23
+
24
+ # pylint: disable=missing-function-docstring
25
+
26
+ class PGGANDiscriminator(nn.Module):
27
+ """Defines the discriminator network in PGGAN.
28
+
29
+ NOTE: The discriminator takes images with `RGB` channel order and pixel
30
+ range [-1, 1] as inputs.
31
+
32
+ Settings for the network:
33
+
34
+ (1) resolution: The resolution of the input image.
35
+ (2) init_res: Smallest resolution of the convolutional backbone.
36
+ (default: 4)
37
+ (3) image_channels: Number of channels of the input image. (default: 3)
38
+ (4) label_dim: Dimension of the additional label for conditional generation.
39
+ In one-hot conditioning case, it is equal to the number of classes. If
40
+ set to 0, conditioning training will be disabled. (default: 0)
41
+ (5) fused_scale: Whether to fused `conv2d` and `downsample` together,
42
+ resulting in `conv2d` with strides. (default: False)
43
+ (6) use_wscale: Whether to use weight scaling. (default: True)
44
+ (7) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0))
45
+ (8) mbstd_groups: Group size for the minibatch standard deviation layer.
46
+ `0` means disable. (default: 16)
47
+ (9) fmaps_base: Factor to control number of feature maps for each layer.
48
+ (default: 16 << 10)
49
+ (10) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
50
+ (11) eps: A small value to avoid divide overflow. (default: 1e-8)
51
+ """
52
+
53
+ def __init__(self,
54
+ resolution,
55
+ init_res=4,
56
+ image_channels=3,
57
+ label_dim=0,
58
+ fused_scale=False,
59
+ use_wscale=True,
60
+ wscale_gain=np.sqrt(2.0),
61
+ mbstd_groups=16,
62
+ fmaps_base=16 << 10,
63
+ fmaps_max=512,
64
+ eps=1e-8):
65
+ """Initializes with basic settings.
66
+
67
+ Raises:
68
+ ValueError: If the `resolution` is not supported.
69
+ """
70
+ super().__init__()
71
+
72
+ if resolution not in _RESOLUTIONS_ALLOWED:
73
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
74
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
75
+
76
+ self.init_res = init_res
77
+ self.init_res_log2 = int(np.log2(self.init_res))
78
+ self.resolution = resolution
79
+ self.final_res_log2 = int(np.log2(self.resolution))
80
+ self.image_channels = image_channels
81
+ self.label_dim = label_dim
82
+ self.fused_scale = fused_scale
83
+ self.use_wscale = use_wscale
84
+ self.wscale_gain = wscale_gain
85
+ self.mbstd_groups = mbstd_groups
86
+ self.fmaps_base = fmaps_base
87
+ self.fmaps_max = fmaps_max
88
+ self.eps = eps
89
+
90
+ # Level-of-details (used for progressive training).
91
+ self.register_buffer('lod', torch.zeros(()))
92
+ self.pth_to_tf_var_mapping = {'lod': 'lod'}
93
+
94
+ for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
95
+ res = 2 ** res_log2
96
+ in_channels = self.get_nf(res)
97
+ out_channels = self.get_nf(res // 2)
98
+ block_idx = self.final_res_log2 - res_log2
99
+
100
+ # Input convolution layer for each resolution.
101
+ self.add_module(
102
+ f'input{block_idx}',
103
+ ConvLayer(in_channels=self.image_channels,
104
+ out_channels=in_channels,
105
+ kernel_size=1,
106
+ add_bias=True,
107
+ downsample=False,
108
+ fused_scale=False,
109
+ use_wscale=use_wscale,
110
+ wscale_gain=wscale_gain,
111
+ activation_type='lrelu'))
112
+ self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
113
+ f'FromRGB_lod{block_idx}/weight')
114
+ self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
115
+ f'FromRGB_lod{block_idx}/bias')
116
+
117
+ # Convolution block for each resolution (except the last one).
118
+ if res != self.init_res:
119
+ self.add_module(
120
+ f'layer{2 * block_idx}',
121
+ ConvLayer(in_channels=in_channels,
122
+ out_channels=in_channels,
123
+ kernel_size=3,
124
+ add_bias=True,
125
+ downsample=False,
126
+ fused_scale=False,
127
+ use_wscale=use_wscale,
128
+ wscale_gain=wscale_gain,
129
+ activation_type='lrelu'))
130
+ tf_layer0_name = 'Conv0'
131
+ self.add_module(
132
+ f'layer{2 * block_idx + 1}',
133
+ ConvLayer(in_channels=in_channels,
134
+ out_channels=out_channels,
135
+ kernel_size=3,
136
+ add_bias=True,
137
+ downsample=True,
138
+ fused_scale=fused_scale,
139
+ use_wscale=use_wscale,
140
+ wscale_gain=wscale_gain,
141
+ activation_type='lrelu'))
142
+ tf_layer1_name = 'Conv1_down' if fused_scale else 'Conv1'
143
+
144
+ # Convolution block for last resolution.
145
+ else:
146
+ self.mbstd = MiniBatchSTDLayer(groups=mbstd_groups, eps=eps)
147
+ self.add_module(
148
+ f'layer{2 * block_idx}',
149
+ ConvLayer(
150
+ in_channels=in_channels + 1,
151
+ out_channels=in_channels,
152
+ kernel_size=3,
153
+ add_bias=True,
154
+ downsample=False,
155
+ fused_scale=False,
156
+ use_wscale=use_wscale,
157
+ wscale_gain=wscale_gain,
158
+ activation_type='lrelu'))
159
+ tf_layer0_name = 'Conv'
160
+ self.add_module(
161
+ f'layer{2 * block_idx + 1}',
162
+ DenseLayer(in_channels=in_channels * res * res,
163
+ out_channels=out_channels,
164
+ add_bias=True,
165
+ use_wscale=use_wscale,
166
+ wscale_gain=wscale_gain,
167
+ activation_type='lrelu'))
168
+ tf_layer1_name = 'Dense0'
169
+
170
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
171
+ f'{res}x{res}/{tf_layer0_name}/weight')
172
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
173
+ f'{res}x{res}/{tf_layer0_name}/bias')
174
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
175
+ f'{res}x{res}/{tf_layer1_name}/weight')
176
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
177
+ f'{res}x{res}/{tf_layer1_name}/bias')
178
+
179
+ # Final dense layer.
180
+ self.output = DenseLayer(in_channels=out_channels,
181
+ out_channels=1 + self.label_dim,
182
+ add_bias=True,
183
+ use_wscale=self.use_wscale,
184
+ wscale_gain=1.0,
185
+ activation_type='linear')
186
+ self.pth_to_tf_var_mapping['output.weight'] = (
187
+ f'{res}x{res}/Dense1/weight')
188
+ self.pth_to_tf_var_mapping['output.bias'] = (
189
+ f'{res}x{res}/Dense1/bias')
190
+
191
+ def get_nf(self, res):
192
+ """Gets number of feature maps according to the given resolution."""
193
+ return min(self.fmaps_base // res, self.fmaps_max)
194
+
195
+ def forward(self, image, lod=None):
196
+ expected_shape = (self.image_channels, self.resolution, self.resolution)
197
+ if image.ndim != 4 or image.shape[1:] != expected_shape:
198
+ raise ValueError(f'The input tensor should be with shape '
199
+ f'[batch_size, channel, height, width], where '
200
+ f'`channel` equals to {self.image_channels}, '
201
+ f'`height`, `width` equal to {self.resolution}!\n'
202
+ f'But `{image.shape}` is received!')
203
+
204
+ lod = self.lod.item() if lod is None else lod
205
+ if lod + self.init_res_log2 > self.final_res_log2:
206
+ raise ValueError(f'Maximum level-of-details (lod) is '
207
+ f'{self.final_res_log2 - self.init_res_log2}, '
208
+ f'but `{lod}` is received!')
209
+
210
+ lod = self.lod.item()
211
+ for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
212
+ block_idx = current_lod = self.final_res_log2 - res_log2
213
+ if current_lod <= lod < current_lod + 1:
214
+ x = getattr(self, f'input{block_idx}')(image)
215
+ elif current_lod - 1 < lod < current_lod:
216
+ alpha = lod - np.floor(lod)
217
+ y = getattr(self, f'input{block_idx}')(image)
218
+ x = y * alpha + x * (1 - alpha)
219
+ if lod < current_lod + 1:
220
+ if res_log2 == self.init_res_log2:
221
+ x = self.mbstd(x)
222
+ x = getattr(self, f'layer{2 * block_idx}')(x)
223
+ x = getattr(self, f'layer{2 * block_idx + 1}')(x)
224
+ if lod > current_lod:
225
+ image = F.avg_pool2d(
226
+ image, kernel_size=2, stride=2, padding=0)
227
+ x = self.output(x)
228
+
229
+ return {'score': x}
230
+
231
+
232
+ class MiniBatchSTDLayer(nn.Module):
233
+ """Implements the minibatch standard deviation layer."""
234
+
235
+ def __init__(self, groups, eps):
236
+ super().__init__()
237
+ self.groups = groups
238
+ self.eps = eps
239
+
240
+ def extra_repr(self):
241
+ return f'groups={self.groups}, epsilon={self.eps}'
242
+
243
+ def forward(self, x):
244
+ if self.groups <= 1:
245
+ return x
246
+
247
+ N, C, H, W = x.shape
248
+ G = min(self.groups, N) # Number of groups.
249
+
250
+ y = x.reshape(G, -1, C, H, W) # [GnCHW]
251
+ y = y - y.mean(dim=0) # [GnCHW]
252
+ y = y.square().mean(dim=0) # [nCHW]
253
+ y = (y + self.eps).sqrt() # [nCHW]
254
+ y = y.mean(dim=(1, 2, 3), keepdim=True) # [n111]
255
+ y = y.repeat(G, 1, H, W) # [N1HW]
256
+ x = torch.cat([x, y], dim=1) # [N(C+1)HW]
257
+
258
+ return x
259
+
260
+
261
+ class DownsamplingLayer(nn.Module):
262
+ """Implements the downsampling layer.
263
+
264
+ Basically, this layer can be used to downsample feature maps with average
265
+ pooling.
266
+ """
267
+
268
+ def __init__(self, scale_factor):
269
+ super().__init__()
270
+ self.scale_factor = scale_factor
271
+
272
+ def extra_repr(self):
273
+ return f'factor={self.scale_factor}'
274
+
275
+ def forward(self, x):
276
+ if self.scale_factor <= 1:
277
+ return x
278
+ return F.avg_pool2d(x,
279
+ kernel_size=self.scale_factor,
280
+ stride=self.scale_factor,
281
+ padding=0)
282
+
283
+
284
+ class ConvLayer(nn.Module):
285
+ """Implements the convolutional layer.
286
+
287
+ Basically, this layer executes convolution, activation, and downsampling (if
288
+ needed) in sequence.
289
+ """
290
+
291
+ def __init__(self,
292
+ in_channels,
293
+ out_channels,
294
+ kernel_size,
295
+ add_bias,
296
+ downsample,
297
+ fused_scale,
298
+ use_wscale,
299
+ wscale_gain,
300
+ activation_type):
301
+ """Initializes with layer settings.
302
+
303
+ Args:
304
+ in_channels: Number of channels of the input tensor.
305
+ out_channels: Number of channels of the output tensor.
306
+ kernel_size: Size of the convolutional kernels.
307
+ add_bias: Whether to add bias onto the convolutional result.
308
+ downsample: Whether to downsample the result after convolution.
309
+ fused_scale: Whether to fused `conv2d` and `downsample` together,
310
+ resulting in `conv2d` with strides.
311
+ use_wscale: Whether to use weight scaling.
312
+ wscale_gain: Gain factor for weight scaling.
313
+ activation_type: Type of activation.
314
+ """
315
+ super().__init__()
316
+ self.in_channels = in_channels
317
+ self.out_channels = out_channels
318
+ self.kernel_size = kernel_size
319
+ self.add_bias = add_bias
320
+ self.downsample = downsample
321
+ self.fused_scale = fused_scale
322
+ self.use_wscale = use_wscale
323
+ self.wscale_gain = wscale_gain
324
+ self.activation_type = activation_type
325
+
326
+ if downsample and not fused_scale:
327
+ self.down = DownsamplingLayer(scale_factor=2)
328
+ else:
329
+ self.down = nn.Identity()
330
+
331
+ if downsample and fused_scale:
332
+ self.use_stride = True
333
+ self.stride = 2
334
+ self.padding = 1
335
+ else:
336
+ self.use_stride = False
337
+ self.stride = 1
338
+ self.padding = kernel_size // 2
339
+
340
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
341
+ fan_in = kernel_size * kernel_size * in_channels
342
+ wscale = wscale_gain / np.sqrt(fan_in)
343
+ if use_wscale:
344
+ self.weight = nn.Parameter(torch.randn(*weight_shape))
345
+ self.wscale = wscale
346
+ else:
347
+ self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
348
+ self.wscale = 1.0
349
+
350
+ if add_bias:
351
+ self.bias = nn.Parameter(torch.zeros(out_channels))
352
+ else:
353
+ self.bias = None
354
+
355
+ assert activation_type in ['linear', 'relu', 'lrelu']
356
+
357
+ def extra_repr(self):
358
+ return (f'in_ch={self.in_channels}, '
359
+ f'out_ch={self.out_channels}, '
360
+ f'ksize={self.kernel_size}, '
361
+ f'wscale_gain={self.wscale_gain:.3f}, '
362
+ f'bias={self.add_bias}, '
363
+ f'downsample={self.scale_factor}, '
364
+ f'fused_scale={self.fused_scale}, '
365
+ f'act={self.activation_type}')
366
+
367
+ def forward(self, x):
368
+ weight = self.weight
369
+ if self.wscale != 1.0:
370
+ weight = weight * self.wscale
371
+
372
+ if self.use_stride:
373
+ weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
374
+ weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
375
+ weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
376
+ x = F.conv2d(x,
377
+ weight=weight,
378
+ bias=self.bias,
379
+ stride=self.stride,
380
+ padding=self.padding)
381
+
382
+ if self.activation_type == 'linear':
383
+ pass
384
+ elif self.activation_type == 'relu':
385
+ x = F.relu(x, inplace=True)
386
+ elif self.activation_type == 'lrelu':
387
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
388
+ else:
389
+ raise NotImplementedError(f'Not implemented activation type '
390
+ f'`{self.activation_type}`!')
391
+ x = self.down(x)
392
+
393
+ return x
394
+
395
+
396
+ class DenseLayer(nn.Module):
397
+ """Implements the dense layer."""
398
+
399
+ def __init__(self,
400
+ in_channels,
401
+ out_channels,
402
+ add_bias,
403
+ use_wscale,
404
+ wscale_gain,
405
+ activation_type):
406
+ """Initializes with layer settings.
407
+
408
+ Args:
409
+ in_channels: Number of channels of the input tensor.
410
+ out_channels: Number of channels of the output tensor.
411
+ add_bias: Whether to add bias onto the fully-connected result.
412
+ use_wscale: Whether to use weight scaling.
413
+ wscale_gain: Gain factor for weight scaling.
414
+ activation_type: Type of activation.
415
+
416
+ Raises:
417
+ NotImplementedError: If the `activation_type` is not supported.
418
+ """
419
+ super().__init__()
420
+ self.in_channels = in_channels
421
+ self.out_channels = out_channels
422
+ self.add_bias = add_bias
423
+ self.use_wscale = use_wscale
424
+ self.wscale_gain = wscale_gain
425
+ self.activation_type = activation_type
426
+
427
+ weight_shape = (out_channels, in_channels)
428
+ wscale = wscale_gain / np.sqrt(in_channels)
429
+ if use_wscale:
430
+ self.weight = nn.Parameter(torch.randn(*weight_shape))
431
+ self.wscale = wscale
432
+ else:
433
+ self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
434
+ self.wscale = 1.0
435
+
436
+ if add_bias:
437
+ self.bias = nn.Parameter(torch.zeros(out_channels))
438
+ else:
439
+ self.bias = None
440
+
441
+ assert activation_type in ['linear', 'relu', 'lrelu']
442
+
443
+ def forward(self, x):
444
+ if x.ndim != 2:
445
+ x = x.flatten(start_dim=1)
446
+
447
+ weight = self.weight
448
+ if self.wscale != 1.0:
449
+ weight = weight * self.wscale
450
+
451
+ x = F.linear(x, weight=weight, bias=self.bias)
452
+
453
+ if self.activation_type == 'linear':
454
+ pass
455
+ elif self.activation_type == 'relu':
456
+ x = F.relu(x, inplace=True)
457
+ elif self.activation_type == 'lrelu':
458
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
459
+ else:
460
+ raise NotImplementedError(f'Not implemented activation type '
461
+ f'`{self.activation_type}`!')
462
+
463
+ return x
464
+
465
+ # pylint: enable=missing-function-docstring
models/pggan_generator.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the implementation of generator described in PGGAN.
3
+
4
+ Paper: https://arxiv.org/pdf/1710.10196.pdf
5
+
6
+ Official TensorFlow implementation:
7
+ https://github.com/tkarras/progressive_growing_of_gans
8
+ """
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ __all__ = ['PGGANGenerator']
17
+
18
+ # Resolutions allowed.
19
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
20
+
21
+ # pylint: disable=missing-function-docstring
22
+
23
+ class PGGANGenerator(nn.Module):
24
+ """Defines the generator network in PGGAN.
25
+
26
+ NOTE: The synthesized images are with `RGB` channel order and pixel range
27
+ [-1, 1].
28
+
29
+ Settings for the network:
30
+
31
+ (1) resolution: The resolution of the output image.
32
+ (2) init_res: The initial resolution to start with convolution. (default: 4)
33
+ (3) z_dim: Dimension of the input latent space, Z. (default: 512)
34
+ (4) image_channels: Number of channels of the output image. (default: 3)
35
+ (5) final_tanh: Whether to use `tanh` to control the final pixel range.
36
+ (default: False)
37
+ (6) label_dim: Dimension of the additional label for conditional generation.
38
+ In one-hot conditioning case, it is equal to the number of classes. If
39
+ set to 0, conditioning training will be disabled. (default: 0)
40
+ (7) fused_scale: Whether to fused `upsample` and `conv2d` together,
41
+ resulting in `conv2d_transpose`. (default: False)
42
+ (8) use_wscale: Whether to use weight scaling. (default: True)
43
+ (9) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0))
44
+ (10) fmaps_base: Factor to control number of feature maps for each layer.
45
+ (default: 16 << 10)
46
+ (11) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
47
+ (12) eps: A small value to avoid divide overflow. (default: 1e-8)
48
+ """
49
+
50
+ def __init__(self,
51
+ resolution,
52
+ init_res=4,
53
+ z_dim=512,
54
+ image_channels=3,
55
+ final_tanh=False,
56
+ label_dim=0,
57
+ fused_scale=False,
58
+ use_wscale=True,
59
+ wscale_gain=np.sqrt(2.0),
60
+ fmaps_base=16 << 10,
61
+ fmaps_max=512,
62
+ eps=1e-8):
63
+ """Initializes with basic settings.
64
+
65
+ Raises:
66
+ ValueError: If the `resolution` is not supported.
67
+ """
68
+ super().__init__()
69
+
70
+ if resolution not in _RESOLUTIONS_ALLOWED:
71
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
72
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
73
+
74
+ self.init_res = init_res
75
+ self.init_res_log2 = int(np.log2(self.init_res))
76
+ self.resolution = resolution
77
+ self.final_res_log2 = int(np.log2(self.resolution))
78
+ self.z_dim = z_dim
79
+ self.image_channels = image_channels
80
+ self.final_tanh = final_tanh
81
+ self.label_dim = label_dim
82
+ self.fused_scale = fused_scale
83
+ self.use_wscale = use_wscale
84
+ self.wscale_gain = wscale_gain
85
+ self.fmaps_base = fmaps_base
86
+ self.fmaps_max = fmaps_max
87
+ self.eps = eps
88
+
89
+ # Dimension of latent space, which is convenient for sampling.
90
+ self.latent_dim = (self.z_dim,)
91
+
92
+ # Number of convolutional layers.
93
+ self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
94
+
95
+ # Level-of-details (used for progressive training).
96
+ self.register_buffer('lod', torch.zeros(()))
97
+ self.pth_to_tf_var_mapping = {'lod': 'lod'}
98
+
99
+ for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
100
+ res = 2 ** res_log2
101
+ in_channels = self.get_nf(res // 2)
102
+ out_channels = self.get_nf(res)
103
+ block_idx = res_log2 - self.init_res_log2
104
+
105
+ # First convolution layer for each resolution.
106
+ if res == self.init_res:
107
+ self.add_module(
108
+ f'layer{2 * block_idx}',
109
+ ConvLayer(in_channels=z_dim + label_dim,
110
+ out_channels=out_channels,
111
+ kernel_size=init_res,
112
+ padding=init_res - 1,
113
+ add_bias=True,
114
+ upsample=False,
115
+ fused_scale=False,
116
+ use_wscale=use_wscale,
117
+ wscale_gain=wscale_gain,
118
+ activation_type='lrelu',
119
+ eps=eps))
120
+ tf_layer_name = 'Dense'
121
+ else:
122
+ self.add_module(
123
+ f'layer{2 * block_idx}',
124
+ ConvLayer(in_channels=in_channels,
125
+ out_channels=out_channels,
126
+ kernel_size=3,
127
+ padding=1,
128
+ add_bias=True,
129
+ upsample=True,
130
+ fused_scale=fused_scale,
131
+ use_wscale=use_wscale,
132
+ wscale_gain=wscale_gain,
133
+ activation_type='lrelu',
134
+ eps=eps))
135
+ tf_layer_name = 'Conv0_up' if fused_scale else 'Conv0'
136
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
137
+ f'{res}x{res}/{tf_layer_name}/weight')
138
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
139
+ f'{res}x{res}/{tf_layer_name}/bias')
140
+
141
+ # Second convolution layer for each resolution.
142
+ self.add_module(
143
+ f'layer{2 * block_idx + 1}',
144
+ ConvLayer(in_channels=out_channels,
145
+ out_channels=out_channels,
146
+ kernel_size=3,
147
+ padding=1,
148
+ add_bias=True,
149
+ upsample=False,
150
+ fused_scale=False,
151
+ use_wscale=use_wscale,
152
+ wscale_gain=wscale_gain,
153
+ activation_type='lrelu',
154
+ eps=eps))
155
+ tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
156
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
157
+ f'{res}x{res}/{tf_layer_name}/weight')
158
+ self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
159
+ f'{res}x{res}/{tf_layer_name}/bias')
160
+
161
+ # Output convolution layer for each resolution.
162
+ self.add_module(
163
+ f'output{block_idx}',
164
+ ConvLayer(in_channels=out_channels,
165
+ out_channels=image_channels,
166
+ kernel_size=1,
167
+ padding=0,
168
+ add_bias=True,
169
+ upsample=False,
170
+ fused_scale=False,
171
+ use_wscale=use_wscale,
172
+ wscale_gain=1.0,
173
+ activation_type='linear',
174
+ eps=eps))
175
+ self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
176
+ f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
177
+ self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
178
+ f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')
179
+
180
+ def get_nf(self, res):
181
+ """Gets number of feature maps according to the given resolution."""
182
+ return min(self.fmaps_base // res, self.fmaps_max)
183
+
184
+ def forward(self, z, label=None, lod=None):
185
+ if z.ndim != 2 or z.shape[1] != self.z_dim:
186
+ raise ValueError(f'Input latent code should be with shape '
187
+ f'[batch_size, latent_dim], where '
188
+ f'`latent_dim` equals to {self.z_dim}!\n'
189
+ f'But `{z.shape}` is received!')
190
+ z = self.layer0.pixel_norm(z)
191
+ if self.label_dim:
192
+ if label is None:
193
+ raise ValueError(f'Model requires an additional label '
194
+ f'(with size {self.label_dim}) as input, '
195
+ f'but no label is received!')
196
+ if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim):
197
+ raise ValueError(f'Input label should be with shape '
198
+ f'[batch_size, label_dim], where '
199
+ f'`batch_size` equals to that of '
200
+ f'latent codes ({z.shape[0]}) and '
201
+ f'`label_dim` equals to {self.label_dim}!\n'
202
+ f'But `{label.shape}` is received!')
203
+ label = label.to(dtype=torch.float32)
204
+ z = torch.cat((z, label), dim=1)
205
+
206
+ lod = self.lod.item() if lod is None else lod
207
+ if lod + self.init_res_log2 > self.final_res_log2:
208
+ raise ValueError(f'Maximum level-of-details (lod) is '
209
+ f'{self.final_res_log2 - self.init_res_log2}, '
210
+ f'but `{lod}` is received!')
211
+
212
+ x = z.view(z.shape[0], self.z_dim + self.label_dim, 1, 1)
213
+ for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
214
+ current_lod = self.final_res_log2 - res_log2
215
+ block_idx = res_log2 - self.init_res_log2
216
+ if lod < current_lod + 1:
217
+ x = getattr(self, f'layer{2 * block_idx}')(x)
218
+ x = getattr(self, f'layer{2 * block_idx + 1}')(x)
219
+ if current_lod - 1 < lod <= current_lod:
220
+ image = getattr(self, f'output{block_idx}')(x)
221
+ elif current_lod < lod < current_lod + 1:
222
+ alpha = np.ceil(lod) - lod
223
+ temp = getattr(self, f'output{block_idx}')(x)
224
+ image = F.interpolate(image, scale_factor=2, mode='nearest')
225
+ image = temp * alpha + image * (1 - alpha)
226
+ elif lod >= current_lod + 1:
227
+ image = F.interpolate(image, scale_factor=2, mode='nearest')
228
+ if self.final_tanh:
229
+ image = torch.tanh(image)
230
+
231
+ results = {
232
+ 'z': z,
233
+ 'label': label,
234
+ 'image': image,
235
+ }
236
+ return results
237
+
238
+
239
+ class PixelNormLayer(nn.Module):
240
+ """Implements pixel-wise feature vector normalization layer."""
241
+
242
+ def __init__(self, dim, eps):
243
+ super().__init__()
244
+ self.dim = dim
245
+ self.eps = eps
246
+
247
+ def extra_repr(self):
248
+ return f'dim={self.dim}, epsilon={self.eps}'
249
+
250
+ def forward(self, x):
251
+ scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
252
+ return x * scale
253
+
254
+
255
+ class UpsamplingLayer(nn.Module):
256
+ """Implements the upsampling layer.
257
+
258
+ Basically, this layer can be used to upsample feature maps with nearest
259
+ neighbor interpolation.
260
+ """
261
+
262
+ def __init__(self, scale_factor):
263
+ super().__init__()
264
+ self.scale_factor = scale_factor
265
+
266
+ def extra_repr(self):
267
+ return f'factor={self.scale_factor}'
268
+
269
+ def forward(self, x):
270
+ if self.scale_factor <= 1:
271
+ return x
272
+ return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
273
+
274
+
275
+ class ConvLayer(nn.Module):
276
+ """Implements the convolutional layer.
277
+
278
+ Basically, this layer executes pixel-wise normalization, upsampling (if
279
+ needed), convolution, and activation in sequence.
280
+ """
281
+
282
+ def __init__(self,
283
+ in_channels,
284
+ out_channels,
285
+ kernel_size,
286
+ padding,
287
+ add_bias,
288
+ upsample,
289
+ fused_scale,
290
+ use_wscale,
291
+ wscale_gain,
292
+ activation_type,
293
+ eps):
294
+ """Initializes with layer settings.
295
+
296
+ Args:
297
+ in_channels: Number of channels of the input tensor.
298
+ out_channels: Number of channels of the output tensor.
299
+ kernel_size: Size of the convolutional kernels.
300
+ padding: Padding used in convolution.
301
+ add_bias: Whether to add bias onto the convolutional result.
302
+ upsample: Whether to upsample the input tensor before convolution.
303
+ fused_scale: Whether to fused `upsample` and `conv2d` together,
304
+ resulting in `conv2d_transpose`.
305
+ use_wscale: Whether to use weight scaling.
306
+ wscale_gain: Gain factor for weight scaling.
307
+ activation_type: Type of activation.
308
+ eps: A small value to avoid divide overflow.
309
+ """
310
+ super().__init__()
311
+ self.in_channels = in_channels
312
+ self.out_channels = out_channels
313
+ self.kernel_size = kernel_size
314
+ self.padding = padding
315
+ self.add_bias = add_bias
316
+ self.upsample = upsample
317
+ self.fused_scale = fused_scale
318
+ self.use_wscale = use_wscale
319
+ self.wscale_gain = wscale_gain
320
+ self.activation_type = activation_type
321
+ self.eps = eps
322
+
323
+ self.pixel_norm = PixelNormLayer(dim=1, eps=eps)
324
+
325
+ if upsample and not fused_scale:
326
+ self.up = UpsamplingLayer(scale_factor=2)
327
+ else:
328
+ self.up = nn.Identity()
329
+
330
+ if upsample and fused_scale:
331
+ self.use_conv2d_transpose = True
332
+ weight_shape = (in_channels, out_channels, kernel_size, kernel_size)
333
+ self.stride = 2
334
+ self.padding = 1
335
+ else:
336
+ self.use_conv2d_transpose = False
337
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
338
+ self.stride = 1
339
+
340
+ fan_in = kernel_size * kernel_size * in_channels
341
+ wscale = wscale_gain / np.sqrt(fan_in)
342
+ if use_wscale:
343
+ self.weight = nn.Parameter(torch.randn(*weight_shape))
344
+ self.wscale = wscale
345
+ else:
346
+ self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
347
+ self.wscale = 1.0
348
+
349
+ if add_bias:
350
+ self.bias = nn.Parameter(torch.zeros(out_channels))
351
+ else:
352
+ self.bias = None
353
+
354
+ assert activation_type in ['linear', 'relu', 'lrelu']
355
+
356
+ def extra_repr(self):
357
+ return (f'in_ch={self.in_channels}, '
358
+ f'out_ch={self.out_channels}, '
359
+ f'ksize={self.kernel_size}, '
360
+ f'padding={self.padding}, '
361
+ f'wscale_gain={self.wscale_gain:.3f}, '
362
+ f'bias={self.add_bias}, '
363
+ f'upsample={self.scale_factor}, '
364
+ f'fused_scale={self.fused_scale}, '
365
+ f'act={self.activation_type}')
366
+
367
+ def forward(self, x):
368
+ x = self.pixel_norm(x)
369
+ x = self.up(x)
370
+ weight = self.weight
371
+ if self.wscale != 1.0:
372
+ weight = weight * self.wscale
373
+ if self.use_conv2d_transpose:
374
+ weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
375
+ weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
376
+ weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
377
+ x = F.conv_transpose2d(x,
378
+ weight=weight,
379
+ bias=self.bias,
380
+ stride=self.stride,
381
+ padding=self.padding)
382
+ else:
383
+ x = F.conv2d(x,
384
+ weight=weight,
385
+ bias=self.bias,
386
+ stride=self.stride,
387
+ padding=self.padding)
388
+
389
+ if self.activation_type == 'linear':
390
+ pass
391
+ elif self.activation_type == 'relu':
392
+ x = F.relu(x, inplace=True)
393
+ elif self.activation_type == 'lrelu':
394
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
395
+ else:
396
+ raise NotImplementedError(f'Not implemented activation type '
397
+ f'`{self.activation_type}`!')
398
+
399
+ return x
400
+
401
+ # pylint: enable=missing-function-docstring
models/stylegan2_discriminator.py ADDED
@@ -0,0 +1,729 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the implementation of discriminator described in StyleGAN2.
3
+
4
+ Compared to that of StyleGAN, the discriminator in StyleGAN2 mainly adds skip
5
+ connections, increases model size and disables progressive growth. This script
6
+ ONLY supports config F in the original paper.
7
+
8
+ Paper: https://arxiv.org/pdf/1912.04958.pdf
9
+
10
+ Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
11
+ """
12
+
13
+ import numpy as np
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from third_party.stylegan2_official_ops import bias_act
19
+ from third_party.stylegan2_official_ops import upfirdn2d
20
+ from third_party.stylegan2_official_ops import conv2d_gradfix
21
+
22
+ __all__ = ['StyleGAN2Discriminator']
23
+
24
+ # Resolutions allowed.
25
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
26
+
27
+ # Architectures allowed.
28
+ _ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
29
+
30
+ # pylint: disable=missing-function-docstring
31
+
32
+ class StyleGAN2Discriminator(nn.Module):
33
+ """Defines the discriminator network in StyleGAN2.
34
+
35
+ NOTE: The discriminator takes images with `RGB` channel order and pixel
36
+ range [-1, 1] as inputs.
37
+
38
+ Settings for the backbone:
39
+
40
+ (1) resolution: The resolution of the input image. (default: -1)
41
+ (2) init_res: Smallest resolution of the convolutional backbone.
42
+ (default: 4)
43
+ (3) image_channels: Number of channels of the input image. (default: 3)
44
+ (4) architecture: Type of architecture. Support `origin`, `skip`, and
45
+ `resnet`. (default: `resnet`)
46
+ (5) use_wscale: Whether to use weight scaling. (default: True)
47
+ (6) wscale_gain: The factor to control weight scaling. (default: 1.0)
48
+ (7) lr_mul: Learning rate multiplier for backbone. (default: 1.0)
49
+ (8) mbstd_groups: Group size for the minibatch standard deviation layer.
50
+ `0` means disable. (default: 4)
51
+ (9) mbstd_channels: Number of new channels (appended to the original feature
52
+ map) after the minibatch standard deviation layer. (default: 1)
53
+ (10) fmaps_base: Factor to control number of feature maps for each layer.
54
+ (default: 32 << 10)
55
+ (11) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
56
+ (12) filter_kernel: Kernel used for filtering (e.g., downsampling).
57
+ (default: (1, 3, 3, 1))
58
+ (13) conv_clamp: A threshold to clamp the output of convolution layers to
59
+ avoid overflow under FP16 training. (default: None)
60
+ (14) eps: A small value to avoid divide overflow. (default: 1e-8)
61
+
62
+ Settings for conditional model:
63
+
64
+ (1) label_dim: Dimension of the additional label for conditional generation.
65
+ In one-hot conditioning case, it is equal to the number of classes. If
66
+ set to 0, conditioning training will be disabled. (default: 0)
67
+ (2) embedding_dim: Dimension of the embedding space, if needed.
68
+ (default: 512)
69
+ (3) embedding_bias: Whether to add bias to embedding learning.
70
+ (default: True)
71
+ (4) embedding_use_wscale: Whether to use weight scaling for embedding
72
+ learning. (default: True)
73
+ (5) embedding_lr_mul: Learning rate multiplier for the embedding learning.
74
+ (default: 1.0)
75
+ (6) normalize_embedding: Whether to normalize the embedding. (default: True)
76
+ (7) mapping_layers: Number of layers of the additional mapping network after
77
+ embedding. (default: 0)
78
+ (8) mapping_fmaps: Number of hidden channels of the additional mapping
79
+ network after embedding. (default: 512)
80
+ (9) mapping_use_wscale: Whether to use weight scaling for the additional
81
+ mapping network. (default: True)
82
+ (10) mapping_lr_mul: Learning rate multiplier for the additional mapping
83
+ network after embedding. (default: 0.1)
84
+
85
+ Runtime settings:
86
+
87
+ (1) fp16_res: Layers at resolution higher than (or equal to) this field will
88
+ use `float16` precision for computation. This is merely used for
89
+ acceleration. If set as `None`, all layers will use `float32` by
90
+ default. (default: None)
91
+ (2) impl: Implementation mode of some particular ops, e.g., `filtering`,
92
+ `bias_act`, etc. `cuda` means using the official CUDA implementation
93
+ from StyleGAN2, while `ref` means using the native PyTorch ops.
94
+ (default: `cuda`)
95
+ """
96
+
97
+ def __init__(self,
98
+ # Settings for backbone.
99
+ resolution=-1,
100
+ init_res=4,
101
+ image_channels=3,
102
+ architecture='resnet',
103
+ use_wscale=True,
104
+ wscale_gain=1.0,
105
+ lr_mul=1.0,
106
+ mbstd_groups=4,
107
+ mbstd_channels=1,
108
+ fmaps_base=32 << 10,
109
+ fmaps_max=512,
110
+ filter_kernel=(1, 3, 3, 1),
111
+ conv_clamp=None,
112
+ eps=1e-8,
113
+ # Settings for conditional model.
114
+ label_dim=0,
115
+ embedding_dim=512,
116
+ embedding_bias=True,
117
+ embedding_use_wscale=True,
118
+ embedding_lr_mul=1.0,
119
+ normalize_embedding=True,
120
+ mapping_layers=0,
121
+ mapping_fmaps=512,
122
+ mapping_use_wscale=True,
123
+ mapping_lr_mul=0.1):
124
+ """Initializes with basic settings.
125
+
126
+ Raises:
127
+ ValueError: If the `resolution` is not supported, or `architecture`
128
+ is not supported.
129
+ """
130
+ super().__init__()
131
+
132
+ if resolution not in _RESOLUTIONS_ALLOWED:
133
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
134
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
135
+ architecture = architecture.lower()
136
+ if architecture not in _ARCHITECTURES_ALLOWED:
137
+ raise ValueError(f'Invalid architecture: `{architecture}`!\n'
138
+ f'Architectures allowed: '
139
+ f'{_ARCHITECTURES_ALLOWED}.')
140
+
141
+ self.init_res = init_res
142
+ self.init_res_log2 = int(np.log2(init_res))
143
+ self.resolution = resolution
144
+ self.final_res_log2 = int(np.log2(resolution))
145
+ self.image_channels = image_channels
146
+ self.architecture = architecture
147
+ self.use_wscale = use_wscale
148
+ self.wscale_gain = wscale_gain
149
+ self.lr_mul = lr_mul
150
+ self.mbstd_groups = mbstd_groups
151
+ self.mbstd_channels = mbstd_channels
152
+ self.fmaps_base = fmaps_base
153
+ self.fmaps_max = fmaps_max
154
+ self.filter_kernel = filter_kernel
155
+ self.conv_clamp = conv_clamp
156
+ self.eps = eps
157
+
158
+ self.label_dim = label_dim
159
+ self.embedding_dim = embedding_dim
160
+ self.embedding_bias = embedding_bias
161
+ self.embedding_use_wscale = embedding_use_wscale
162
+ self.embedding_lr_mul = embedding_lr_mul
163
+ self.normalize_embedding = normalize_embedding
164
+ self.mapping_layers = mapping_layers
165
+ self.mapping_fmaps = mapping_fmaps
166
+ self.mapping_use_wscale = mapping_use_wscale
167
+ self.mapping_lr_mul = mapping_lr_mul
168
+
169
+ self.pth_to_tf_var_mapping = {}
170
+
171
+ # Embedding for conditional discrimination.
172
+ self.use_embedding = label_dim > 0 and embedding_dim > 0
173
+ if self.use_embedding:
174
+ self.embedding = DenseLayer(in_channels=label_dim,
175
+ out_channels=embedding_dim,
176
+ add_bias=embedding_bias,
177
+ init_bias=0.0,
178
+ use_wscale=embedding_use_wscale,
179
+ wscale_gain=wscale_gain,
180
+ lr_mul=embedding_lr_mul,
181
+ activation_type='linear')
182
+ self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight'
183
+ if self.embedding_bias:
184
+ self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias'
185
+
186
+ if self.normalize_embedding:
187
+ self.norm = PixelNormLayer(dim=1, eps=eps)
188
+
189
+ for i in range(mapping_layers):
190
+ in_channels = (embedding_dim if i == 0 else mapping_fmaps)
191
+ out_channels = (embedding_dim if i == (mapping_layers - 1) else
192
+ mapping_fmaps)
193
+ layer_name = f'mapping{i}'
194
+ self.add_module(layer_name,
195
+ DenseLayer(in_channels=in_channels,
196
+ out_channels=out_channels,
197
+ add_bias=True,
198
+ init_bias=0.0,
199
+ use_wscale=mapping_use_wscale,
200
+ wscale_gain=wscale_gain,
201
+ lr_mul=mapping_lr_mul,
202
+ activation_type='lrelu'))
203
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
204
+ f'Mapping{i}/weight')
205
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
206
+ f'Mapping{i}/bias')
207
+
208
+ # Convolutional backbone.
209
+ for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
210
+ res = 2 ** res_log2
211
+ in_channels = self.get_nf(res)
212
+ out_channels = self.get_nf(res // 2)
213
+ block_idx = self.final_res_log2 - res_log2
214
+
215
+ # Input convolution layer for each resolution (if needed).
216
+ if res_log2 == self.final_res_log2 or self.architecture == 'skip':
217
+ layer_name = f'input{block_idx}'
218
+ self.add_module(layer_name,
219
+ ConvLayer(in_channels=image_channels,
220
+ out_channels=in_channels,
221
+ kernel_size=1,
222
+ add_bias=True,
223
+ scale_factor=1,
224
+ filter_kernel=None,
225
+ use_wscale=use_wscale,
226
+ wscale_gain=wscale_gain,
227
+ lr_mul=lr_mul,
228
+ activation_type='lrelu',
229
+ conv_clamp=conv_clamp))
230
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
231
+ f'{res}x{res}/FromRGB/weight')
232
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
233
+ f'{res}x{res}/FromRGB/bias')
234
+
235
+ # Convolution block for each resolution (except the last one).
236
+ if res != self.init_res:
237
+ # First layer (kernel 3x3) without downsampling.
238
+ layer_name = f'layer{2 * block_idx}'
239
+ self.add_module(layer_name,
240
+ ConvLayer(in_channels=in_channels,
241
+ out_channels=in_channels,
242
+ kernel_size=3,
243
+ add_bias=True,
244
+ scale_factor=1,
245
+ filter_kernel=None,
246
+ use_wscale=use_wscale,
247
+ wscale_gain=wscale_gain,
248
+ lr_mul=lr_mul,
249
+ activation_type='lrelu',
250
+ conv_clamp=conv_clamp))
251
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
252
+ f'{res}x{res}/Conv0/weight')
253
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
254
+ f'{res}x{res}/Conv0/bias')
255
+
256
+ # Second layer (kernel 3x3) with downsampling
257
+ layer_name = f'layer{2 * block_idx + 1}'
258
+ self.add_module(layer_name,
259
+ ConvLayer(in_channels=in_channels,
260
+ out_channels=out_channels,
261
+ kernel_size=3,
262
+ add_bias=True,
263
+ scale_factor=2,
264
+ filter_kernel=filter_kernel,
265
+ use_wscale=use_wscale,
266
+ wscale_gain=wscale_gain,
267
+ lr_mul=lr_mul,
268
+ activation_type='lrelu',
269
+ conv_clamp=conv_clamp))
270
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
271
+ f'{res}x{res}/Conv1_down/weight')
272
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
273
+ f'{res}x{res}/Conv1_down/bias')
274
+
275
+ # Residual branch (kernel 1x1) with downsampling, without bias,
276
+ # with linear activation.
277
+ if self.architecture == 'resnet':
278
+ layer_name = f'residual{block_idx}'
279
+ self.add_module(layer_name,
280
+ ConvLayer(in_channels=in_channels,
281
+ out_channels=out_channels,
282
+ kernel_size=1,
283
+ add_bias=False,
284
+ scale_factor=2,
285
+ filter_kernel=filter_kernel,
286
+ use_wscale=use_wscale,
287
+ wscale_gain=wscale_gain,
288
+ lr_mul=lr_mul,
289
+ activation_type='linear',
290
+ conv_clamp=None))
291
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
292
+ f'{res}x{res}/Skip/weight')
293
+
294
+ # Convolution block for last resolution.
295
+ else:
296
+ self.mbstd = MiniBatchSTDLayer(
297
+ groups=mbstd_groups, new_channels=mbstd_channels, eps=eps)
298
+
299
+ # First layer (kernel 3x3) without downsampling.
300
+ layer_name = f'layer{2 * block_idx}'
301
+ self.add_module(
302
+ layer_name,
303
+ ConvLayer(in_channels=in_channels + mbstd_channels,
304
+ out_channels=in_channels,
305
+ kernel_size=3,
306
+ add_bias=True,
307
+ scale_factor=1,
308
+ filter_kernel=None,
309
+ use_wscale=use_wscale,
310
+ wscale_gain=wscale_gain,
311
+ lr_mul=lr_mul,
312
+ activation_type='lrelu',
313
+ conv_clamp=conv_clamp))
314
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
315
+ f'{res}x{res}/Conv/weight')
316
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
317
+ f'{res}x{res}/Conv/bias')
318
+
319
+ # Second layer, as a fully-connected layer.
320
+ layer_name = f'layer{2 * block_idx + 1}'
321
+ self.add_module(layer_name,
322
+ DenseLayer(in_channels=in_channels * res * res,
323
+ out_channels=in_channels,
324
+ add_bias=True,
325
+ init_bias=0.0,
326
+ use_wscale=use_wscale,
327
+ wscale_gain=wscale_gain,
328
+ lr_mul=lr_mul,
329
+ activation_type='lrelu'))
330
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
331
+ f'{res}x{res}/Dense0/weight')
332
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
333
+ f'{res}x{res}/Dense0/bias')
334
+
335
+ # Final dense layer to output score.
336
+ self.output = DenseLayer(in_channels=in_channels,
337
+ out_channels=(embedding_dim
338
+ if self.use_embedding
339
+ else max(label_dim, 1)),
340
+ add_bias=True,
341
+ init_bias=0.0,
342
+ use_wscale=use_wscale,
343
+ wscale_gain=wscale_gain,
344
+ lr_mul=lr_mul,
345
+ activation_type='linear')
346
+ self.pth_to_tf_var_mapping['output.weight'] = 'Output/weight'
347
+ self.pth_to_tf_var_mapping['output.bias'] = 'Output/bias'
348
+
349
+ # Used for downsampling input image for `skip` architecture.
350
+ if self.architecture == 'skip':
351
+ self.register_buffer(
352
+ 'filter', upfirdn2d.setup_filter(filter_kernel))
353
+
354
+ def get_nf(self, res):
355
+ """Gets number of feature maps according to the given resolution."""
356
+ return min(self.fmaps_base // res, self.fmaps_max)
357
+
358
+ def forward(self, image, label=None, fp16_res=None, impl='cuda'):
359
+ # Check shape.
360
+ expected_shape = (self.image_channels, self.resolution, self.resolution)
361
+ if image.ndim != 4 or image.shape[1:] != expected_shape:
362
+ raise ValueError(f'The input tensor should be with shape '
363
+ f'[batch_size, channel, height, width], where '
364
+ f'`channel` equals to {self.image_channels}, '
365
+ f'`height`, `width` equal to {self.resolution}!\n'
366
+ f'But `{image.shape}` is received!')
367
+ if self.label_dim > 0:
368
+ if label is None:
369
+ raise ValueError(f'Model requires an additional label '
370
+ f'(with dimension {self.label_dim}) as input, '
371
+ f'but no label is received!')
372
+ batch_size = image.shape[0]
373
+ if label.ndim != 2 or label.shape != (batch_size, self.label_dim):
374
+ raise ValueError(f'Input label should be with shape '
375
+ f'[batch_size, label_dim], where '
376
+ f'`batch_size` equals to that of '
377
+ f'images ({image.shape[0]}) and '
378
+ f'`label_dim` equals to {self.label_dim}!\n'
379
+ f'But `{label.shape}` is received!')
380
+ label = label.to(dtype=torch.float32)
381
+ if self.use_embedding:
382
+ embed = self.embedding(label, impl=impl)
383
+ if self.normalize_embedding:
384
+ embed = self.norm(embed)
385
+ for i in range(self.mapping_layers):
386
+ embed = getattr(self, f'mapping{i}')(embed, impl=impl)
387
+
388
+ # Cast to `torch.float16` if needed.
389
+ if fp16_res is not None and self.resolution >= fp16_res:
390
+ image = image.to(torch.float16)
391
+
392
+ x = self.input0(image, impl=impl)
393
+
394
+ for res_log2 in range(self.final_res_log2, self.init_res_log2, -1):
395
+ res = 2 ** res_log2
396
+ # Cast to `torch.float16` if needed.
397
+ if fp16_res is not None and res >= fp16_res:
398
+ x = x.to(torch.float16)
399
+ else:
400
+ x = x.to(torch.float32)
401
+
402
+ idx = self.final_res_log2 - res_log2 # Block index
403
+
404
+ if self.architecture == 'skip' and idx > 0:
405
+ image = upfirdn2d.downsample2d(image, self.filter, impl=impl)
406
+ # Cast to `torch.float16` if needed.
407
+ if fp16_res is not None and res >= fp16_res:
408
+ image = image.to(torch.float16)
409
+ else:
410
+ image = image.to(torch.float32)
411
+ y = getattr(self, f'input{idx}')(image, impl=impl)
412
+ x = x + y
413
+
414
+ if self.architecture == 'resnet':
415
+ residual = getattr(self, f'residual{idx}')(
416
+ x, runtime_gain=np.sqrt(0.5), impl=impl)
417
+ x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
418
+ x = getattr(self, f'layer{2 * idx + 1}')(
419
+ x, runtime_gain=np.sqrt(0.5), impl=impl)
420
+ x = x + residual
421
+ else:
422
+ x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
423
+ x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl)
424
+
425
+ # Final output.
426
+ idx += 1
427
+ if fp16_res is not None: # Always use FP32 for the last block.
428
+ x = x.to(torch.float32)
429
+ if self.architecture == 'skip':
430
+ image = upfirdn2d.downsample2d(image, self.filter, impl=impl)
431
+ if fp16_res is not None: # Always use FP32 for the last block.
432
+ image = image.to(torch.float32)
433
+ y = getattr(self, f'input{idx}')(image, impl=impl)
434
+ x = x + y
435
+ x = self.mbstd(x)
436
+ x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
437
+ x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl)
438
+ x = self.output(x, impl=impl)
439
+
440
+ if self.use_embedding:
441
+ x = (x * embed).sum(dim=1, keepdim=True)
442
+ x = x / np.sqrt(self.embedding_dim)
443
+ elif self.label_dim > 0:
444
+ x = (x * label).sum(dim=1, keepdim=True)
445
+
446
+ results = {
447
+ 'score': x,
448
+ 'label': label
449
+ }
450
+ if self.use_embedding:
451
+ results['embedding'] = embed
452
+ return results
453
+
454
+
455
+ class PixelNormLayer(nn.Module):
456
+ """Implements pixel-wise feature vector normalization layer."""
457
+
458
+ def __init__(self, dim, eps):
459
+ super().__init__()
460
+ self.dim = dim
461
+ self.eps = eps
462
+
463
+ def extra_repr(self):
464
+ return f'dim={self.dim}, epsilon={self.eps}'
465
+
466
+ def forward(self, x):
467
+ scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
468
+ return x * scale
469
+
470
+
471
+ class MiniBatchSTDLayer(nn.Module):
472
+ """Implements the minibatch standard deviation layer."""
473
+
474
+ def __init__(self, groups, new_channels, eps):
475
+ super().__init__()
476
+ self.groups = groups
477
+ self.new_channels = new_channels
478
+ self.eps = eps
479
+
480
+ def extra_repr(self):
481
+ return (f'groups={self.groups}, '
482
+ f'new_channels={self.new_channels}, '
483
+ f'epsilon={self.eps}')
484
+
485
+ def forward(self, x):
486
+ if self.groups <= 1 or self.new_channels < 1:
487
+ return x
488
+
489
+ dtype = x.dtype
490
+
491
+ N, C, H, W = x.shape
492
+ G = min(self.groups, N) # Number of groups.
493
+ nC = self.new_channels # Number of channel groups.
494
+ c = C // nC # Channels per channel group.
495
+
496
+ y = x.reshape(G, -1, nC, c, H, W) # [GnFcHW]
497
+ y = y - y.mean(dim=0) # [GnFcHW]
498
+ y = y.square().mean(dim=0) # [nFcHW]
499
+ y = (y + self.eps).sqrt() # [nFcHW]
500
+ y = y.mean(dim=(2, 3, 4)) # [nF]
501
+ y = y.reshape(-1, nC, 1, 1) # [nF11]
502
+ y = y.repeat(G, 1, H, W) # [NFHW]
503
+ x = torch.cat((x, y), dim=1) # [N(C+F)HW]
504
+
505
+ assert x.dtype == dtype
506
+ return x
507
+
508
+
509
+ class ConvLayer(nn.Module):
510
+ """Implements the convolutional layer.
511
+
512
+ If downsampling is needed (i.e., `scale_factor = 2`), the feature map will
513
+ be filtered with `filter_kernel` first.
514
+ """
515
+
516
+ def __init__(self,
517
+ in_channels,
518
+ out_channels,
519
+ kernel_size,
520
+ add_bias,
521
+ scale_factor,
522
+ filter_kernel,
523
+ use_wscale,
524
+ wscale_gain,
525
+ lr_mul,
526
+ activation_type,
527
+ conv_clamp):
528
+ """Initializes with layer settings.
529
+
530
+ Args:
531
+ in_channels: Number of channels of the input tensor.
532
+ out_channels: Number of channels of the output tensor.
533
+ kernel_size: Size of the convolutional kernels.
534
+ add_bias: Whether to add bias onto the convolutional result.
535
+ scale_factor: Scale factor for downsampling. `1` means skip
536
+ downsampling.
537
+ filter_kernel: Kernel used for filtering.
538
+ use_wscale: Whether to use weight scaling.
539
+ wscale_gain: Gain factor for weight scaling.
540
+ lr_mul: Learning multiplier for both weight and bias.
541
+ activation_type: Type of activation.
542
+ conv_clamp: A threshold to clamp the output of convolution layers to
543
+ avoid overflow under FP16 training.
544
+ """
545
+ super().__init__()
546
+ self.in_channels = in_channels
547
+ self.out_channels = out_channels
548
+ self.kernel_size = kernel_size
549
+ self.add_bias = add_bias
550
+ self.scale_factor = scale_factor
551
+ self.filter_kernel = filter_kernel
552
+ self.use_wscale = use_wscale
553
+ self.wscale_gain = wscale_gain
554
+ self.lr_mul = lr_mul
555
+ self.activation_type = activation_type
556
+ self.conv_clamp = conv_clamp
557
+
558
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
559
+ fan_in = kernel_size * kernel_size * in_channels
560
+ wscale = wscale_gain / np.sqrt(fan_in)
561
+ if use_wscale:
562
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
563
+ self.wscale = wscale * lr_mul
564
+ else:
565
+ self.weight = nn.Parameter(
566
+ torch.randn(*weight_shape) * wscale / lr_mul)
567
+ self.wscale = lr_mul
568
+
569
+ if add_bias:
570
+ self.bias = nn.Parameter(torch.zeros(out_channels))
571
+ self.bscale = lr_mul
572
+ else:
573
+ self.bias = None
574
+ self.act_gain = bias_act.activation_funcs[activation_type].def_gain
575
+
576
+ if scale_factor > 1:
577
+ assert filter_kernel is not None
578
+ self.register_buffer(
579
+ 'filter', upfirdn2d.setup_filter(filter_kernel))
580
+ fh, fw = self.filter.shape
581
+ self.filter_padding = (
582
+ kernel_size // 2 + (fw - scale_factor + 1) // 2,
583
+ kernel_size // 2 + (fw - scale_factor) // 2,
584
+ kernel_size // 2 + (fh - scale_factor + 1) // 2,
585
+ kernel_size // 2 + (fh - scale_factor) // 2)
586
+
587
+ def extra_repr(self):
588
+ return (f'in_ch={self.in_channels}, '
589
+ f'out_ch={self.out_channels}, '
590
+ f'ksize={self.kernel_size}, '
591
+ f'wscale_gain={self.wscale_gain:.3f}, '
592
+ f'bias={self.add_bias}, '
593
+ f'lr_mul={self.lr_mul:.3f}, '
594
+ f'downsample={self.scale_factor}, '
595
+ f'downsample_filter={self.filter_kernel}, '
596
+ f'act={self.activation_type}, '
597
+ f'clamp={self.conv_clamp}')
598
+
599
+ def forward(self, x, runtime_gain=1.0, impl='cuda'):
600
+ dtype = x.dtype
601
+
602
+ weight = self.weight
603
+ if self.wscale != 1.0:
604
+ weight = weight * self.wscale
605
+ bias = None
606
+ if self.bias is not None:
607
+ bias = self.bias.to(dtype)
608
+ if self.bscale != 1.0:
609
+ bias = bias * self.bscale
610
+
611
+ if self.scale_factor == 1: # Native convolution without downsampling.
612
+ padding = self.kernel_size // 2
613
+ x = conv2d_gradfix.conv2d(
614
+ x, weight.to(dtype), stride=1, padding=padding, impl=impl)
615
+ else: # Convolution with downsampling.
616
+ down = self.scale_factor
617
+ f = self.filter
618
+ padding = self.filter_padding
619
+ # When kernel size = 1, use filtering function for downsampling.
620
+ if self.kernel_size == 1:
621
+ x = upfirdn2d.upfirdn2d(
622
+ x, f, down=down, padding=padding, impl=impl)
623
+ x = conv2d_gradfix.conv2d(
624
+ x, weight.to(dtype), stride=1, padding=0, impl=impl)
625
+ # When kernel size != 1, use stride convolution for downsampling.
626
+ else:
627
+ x = upfirdn2d.upfirdn2d(
628
+ x, f, down=1, padding=padding, impl=impl)
629
+ x = conv2d_gradfix.conv2d(
630
+ x, weight.to(dtype), stride=down, padding=0, impl=impl)
631
+
632
+ act_gain = self.act_gain * runtime_gain
633
+ act_clamp = None
634
+ if self.conv_clamp is not None:
635
+ act_clamp = self.conv_clamp * runtime_gain
636
+ x = bias_act.bias_act(x, bias,
637
+ act=self.activation_type,
638
+ gain=act_gain,
639
+ clamp=act_clamp,
640
+ impl=impl)
641
+
642
+ assert x.dtype == dtype
643
+ return x
644
+
645
+
646
+ class DenseLayer(nn.Module):
647
+ """Implements the dense layer."""
648
+
649
+ def __init__(self,
650
+ in_channels,
651
+ out_channels,
652
+ add_bias,
653
+ init_bias,
654
+ use_wscale,
655
+ wscale_gain,
656
+ lr_mul,
657
+ activation_type):
658
+ """Initializes with layer settings.
659
+
660
+ Args:
661
+ in_channels: Number of channels of the input tensor.
662
+ out_channels: Number of channels of the output tensor.
663
+ add_bias: Whether to add bias onto the fully-connected result.
664
+ init_bias: The initial bias value before training.
665
+ use_wscale: Whether to use weight scaling.
666
+ wscale_gain: Gain factor for weight scaling.
667
+ lr_mul: Learning multiplier for both weight and bias.
668
+ activation_type: Type of activation.
669
+ """
670
+ super().__init__()
671
+ self.in_channels = in_channels
672
+ self.out_channels = out_channels
673
+ self.add_bias = add_bias
674
+ self.init_bias = init_bias
675
+ self.use_wscale = use_wscale
676
+ self.wscale_gain = wscale_gain
677
+ self.lr_mul = lr_mul
678
+ self.activation_type = activation_type
679
+
680
+ weight_shape = (out_channels, in_channels)
681
+ wscale = wscale_gain / np.sqrt(in_channels)
682
+ if use_wscale:
683
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
684
+ self.wscale = wscale * lr_mul
685
+ else:
686
+ self.weight = nn.Parameter(
687
+ torch.randn(*weight_shape) * wscale / lr_mul)
688
+ self.wscale = lr_mul
689
+
690
+ if add_bias:
691
+ init_bias = np.float32(init_bias) / lr_mul
692
+ self.bias = nn.Parameter(torch.full([out_channels], init_bias))
693
+ self.bscale = lr_mul
694
+ else:
695
+ self.bias = None
696
+
697
+ def extra_repr(self):
698
+ return (f'in_ch={self.in_channels}, '
699
+ f'out_ch={self.out_channels}, '
700
+ f'wscale_gain={self.wscale_gain:.3f}, '
701
+ f'bias={self.add_bias}, '
702
+ f'init_bias={self.init_bias}, '
703
+ f'lr_mul={self.lr_mul:.3f}, '
704
+ f'act={self.activation_type}')
705
+
706
+ def forward(self, x, impl='cuda'):
707
+ dtype = x.dtype
708
+
709
+ if x.ndim != 2:
710
+ x = x.flatten(start_dim=1)
711
+
712
+ weight = self.weight.to(dtype) * self.wscale
713
+ bias = None
714
+ if self.bias is not None:
715
+ bias = self.bias.to(dtype)
716
+ if self.bscale != 1.0:
717
+ bias = bias * self.bscale
718
+
719
+ # Fast pass for linear activation.
720
+ if self.activation_type == 'linear' and bias is not None:
721
+ x = torch.addmm(bias.unsqueeze(0), x, weight.t())
722
+ else:
723
+ x = x.matmul(weight.t())
724
+ x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl)
725
+
726
+ assert x.dtype == dtype
727
+ return x
728
+
729
+ # pylint: enable=missing-function-docstring
models/stylegan2_generator.py ADDED
@@ -0,0 +1,1394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the implementation of generator described in StyleGAN2.
3
+
4
+ Compared to that of StyleGAN, the generator in StyleGAN2 mainly introduces style
5
+ demodulation, adds skip connections, increases model size, and disables
6
+ progressive growth. This script ONLY supports config F in the original paper.
7
+
8
+ Paper: https://arxiv.org/pdf/1912.04958.pdf
9
+
10
+ Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
11
+ """
12
+
13
+ import numpy as np
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from third_party.stylegan2_official_ops import fma
19
+ from third_party.stylegan2_official_ops import bias_act
20
+ from third_party.stylegan2_official_ops import upfirdn2d
21
+ from third_party.stylegan2_official_ops import conv2d_gradfix
22
+ from .utils.ops import all_gather
23
+
24
+ __all__ = ['StyleGAN2Generator']
25
+
26
+ # Resolutions allowed.
27
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
28
+
29
+ # Architectures allowed.
30
+ _ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
31
+
32
+ # pylint: disable=missing-function-docstring
33
+
34
+ class StyleGAN2Generator(nn.Module):
35
+ """Defines the generator network in StyleGAN2.
36
+
37
+ NOTE: The synthesized images are with `RGB` channel order and pixel range
38
+ [-1, 1].
39
+
40
+ Settings for the mapping network:
41
+
42
+ (1) z_dim: Dimension of the input latent space, Z. (default: 512)
43
+ (2) w_dim: Dimension of the output latent space, W. (default: 512)
44
+ (3) repeat_w: Repeat w-code for different layers. (default: True)
45
+ (4) normalize_z: Whether to normalize the z-code. (default: True)
46
+ (5) mapping_layers: Number of layers of the mapping network. (default: 8)
47
+ (6) mapping_fmaps: Number of hidden channels of the mapping network.
48
+ (default: 512)
49
+ (7) mapping_use_wscale: Whether to use weight scaling for the mapping
50
+ network. (default: True)
51
+ (8) mapping_wscale_gain: The factor to control weight scaling for the
52
+ mapping network (default: 1.0)
53
+ (9) mapping_lr_mul: Learning rate multiplier for the mapping network.
54
+ (default: 0.01)
55
+
56
+ Settings for conditional generation:
57
+
58
+ (1) label_dim: Dimension of the additional label for conditional generation.
59
+ In one-hot conditioning case, it is equal to the number of classes. If
60
+ set to 0, conditioning training will be disabled. (default: 0)
61
+ (2) embedding_dim: Dimension of the embedding space, if needed.
62
+ (default: 512)
63
+ (3) embedding_bias: Whether to add bias to embedding learning.
64
+ (default: True)
65
+ (4) embedding_use_wscale: Whether to use weight scaling for embedding
66
+ learning. (default: True)
67
+ (5) embedding_wscale_gain: The factor to control weight scaling for
68
+ embedding. (default: 1.0)
69
+ (6) embedding_lr_mul: Learning rate multiplier for the embedding learning.
70
+ (default: 1.0)
71
+ (7) normalize_embedding: Whether to normalize the embedding. (default: True)
72
+ (8) normalize_embedding_latent: Whether to normalize the embedding together
73
+ with the latent. (default: False)
74
+
75
+ Settings for the synthesis network:
76
+
77
+ (1) resolution: The resolution of the output image. (default: -1)
78
+ (2) init_res: The initial resolution to start with convolution. (default: 4)
79
+ (3) image_channels: Number of channels of the output image. (default: 3)
80
+ (4) final_tanh: Whether to use `tanh` to control the final pixel range.
81
+ (default: False)
82
+ (5) const_input: Whether to use a constant in the first convolutional layer.
83
+ (default: True)
84
+ (6) architecture: Type of architecture. Support `origin`, `skip`, and
85
+ `resnet`. (default: `skip`)
86
+ (7) demodulate: Whether to perform style demodulation. (default: True)
87
+ (8) use_wscale: Whether to use weight scaling. (default: True)
88
+ (9) wscale_gain: The factor to control weight scaling. (default: 1.0)
89
+ (10) lr_mul: Learning rate multiplier for the synthesis network.
90
+ (default: 1.0)
91
+ (11) noise_type: Type of noise added to the convolutional results at each
92
+ layer. (default: `spatial`)
93
+ (12) fmaps_base: Factor to control number of feature maps for each layer.
94
+ (default: 32 << 10)
95
+ (13) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
96
+ (14) filter_kernel: Kernel used for filtering (e.g., downsampling).
97
+ (default: (1, 3, 3, 1))
98
+ (15) conv_clamp: A threshold to clamp the output of convolution layers to
99
+ avoid overflow under FP16 training. (default: None)
100
+ (16) eps: A small value to avoid divide overflow. (default: 1e-8)
101
+
102
+ Runtime settings:
103
+
104
+ (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for
105
+ training only. Set `None` to disable. (default: None)
106
+ (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If set
107
+ as `True`, the stats will be more accurate, yet the speed maybe a little
108
+ bit slower. (default: False)
109
+ (3) style_mixing_prob: Probability to perform style mixing as a training
110
+ regularization. Set `None` to disable. (default: None)
111
+ (4) trunc_psi: Truncation psi, set `None` to disable. (default: None)
112
+ (5) trunc_layers: Number of layers to perform truncation. (default: None)
113
+ (6) noise_mode: Mode of the layer-wise noise. Support `none`, `random`,
114
+ `const`. (default: `const`)
115
+ (7) fused_modulate: Whether to fuse `style_modulate` and `conv2d` together.
116
+ (default: False)
117
+ (8) fp16_res: Layers at resolution higher than (or equal to) this field will
118
+ use `float16` precision for computation. This is merely used for
119
+ acceleration. If set as `None`, all layers will use `float32` by
120
+ default. (default: None)
121
+ (9) impl: Implementation mode of some particular ops, e.g., `filtering`,
122
+ `bias_act`, etc. `cuda` means using the official CUDA implementation
123
+ from StyleGAN2, while `ref` means using the native PyTorch ops.
124
+ (default: `cuda`)
125
+ """
126
+
127
+ def __init__(self,
128
+ # Settings for mapping network.
129
+ z_dim=512,
130
+ w_dim=512,
131
+ repeat_w=True,
132
+ normalize_z=True,
133
+ mapping_layers=8,
134
+ mapping_fmaps=512,
135
+ mapping_use_wscale=True,
136
+ mapping_wscale_gain=1.0,
137
+ mapping_lr_mul=0.01,
138
+ # Settings for conditional generation.
139
+ label_dim=0,
140
+ embedding_dim=512,
141
+ embedding_bias=True,
142
+ embedding_use_wscale=True,
143
+ embedding_wscale_gian=1.0,
144
+ embedding_lr_mul=1.0,
145
+ normalize_embedding=True,
146
+ normalize_embedding_latent=False,
147
+ # Settings for synthesis network.
148
+ resolution=-1,
149
+ init_res=4,
150
+ image_channels=3,
151
+ final_tanh=False,
152
+ const_input=True,
153
+ architecture='skip',
154
+ demodulate=True,
155
+ use_wscale=True,
156
+ wscale_gain=1.0,
157
+ lr_mul=1.0,
158
+ noise_type='spatial',
159
+ fmaps_base=32 << 10,
160
+ fmaps_max=512,
161
+ filter_kernel=(1, 3, 3, 1),
162
+ conv_clamp=None,
163
+ eps=1e-8):
164
+ """Initializes with basic settings.
165
+
166
+ Raises:
167
+ ValueError: If the `resolution` is not supported, or `architecture`
168
+ is not supported.
169
+ """
170
+ super().__init__()
171
+
172
+ if resolution not in _RESOLUTIONS_ALLOWED:
173
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
174
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
175
+ architecture = architecture.lower()
176
+ if architecture not in _ARCHITECTURES_ALLOWED:
177
+ raise ValueError(f'Invalid architecture: `{architecture}`!\n'
178
+ f'Architectures allowed: '
179
+ f'{_ARCHITECTURES_ALLOWED}.')
180
+
181
+ self.z_dim = z_dim
182
+ self.w_dim = w_dim
183
+ self.repeat_w = repeat_w
184
+ self.normalize_z = normalize_z
185
+ self.mapping_layers = mapping_layers
186
+ self.mapping_fmaps = mapping_fmaps
187
+ self.mapping_use_wscale = mapping_use_wscale
188
+ self.mapping_wscale_gain = mapping_wscale_gain
189
+ self.mapping_lr_mul = mapping_lr_mul
190
+
191
+ self.label_dim = label_dim
192
+ self.embedding_dim = embedding_dim
193
+ self.embedding_bias = embedding_bias
194
+ self.embedding_use_wscale = embedding_use_wscale
195
+ self.embedding_wscale_gain = embedding_wscale_gian
196
+ self.embedding_lr_mul = embedding_lr_mul
197
+ self.normalize_embedding = normalize_embedding
198
+ self.normalize_embedding_latent = normalize_embedding_latent
199
+
200
+ self.resolution = resolution
201
+ self.init_res = init_res
202
+ self.image_channels = image_channels
203
+ self.final_tanh = final_tanh
204
+ self.const_input = const_input
205
+ self.architecture = architecture
206
+ self.demodulate = demodulate
207
+ self.use_wscale = use_wscale
208
+ self.wscale_gain = wscale_gain
209
+ self.lr_mul = lr_mul
210
+ self.noise_type = noise_type.lower()
211
+ self.fmaps_base = fmaps_base
212
+ self.fmaps_max = fmaps_max
213
+ self.filter_kernel = filter_kernel
214
+ self.conv_clamp = conv_clamp
215
+ self.eps = eps
216
+
217
+ # Dimension of latent space, which is convenient for sampling.
218
+ self.latent_dim = (z_dim,)
219
+
220
+ # Number of synthesis (convolutional) layers.
221
+ self.num_layers = int(np.log2(resolution // init_res * 2)) * 2
222
+
223
+ self.mapping = MappingNetwork(
224
+ input_dim=z_dim,
225
+ output_dim=w_dim,
226
+ num_outputs=self.num_layers,
227
+ repeat_output=repeat_w,
228
+ normalize_input=normalize_z,
229
+ num_layers=mapping_layers,
230
+ hidden_dim=mapping_fmaps,
231
+ use_wscale=mapping_use_wscale,
232
+ wscale_gain=mapping_wscale_gain,
233
+ lr_mul=mapping_lr_mul,
234
+ label_dim=label_dim,
235
+ embedding_dim=embedding_dim,
236
+ embedding_bias=embedding_bias,
237
+ embedding_use_wscale=embedding_use_wscale,
238
+ embedding_wscale_gian=embedding_wscale_gian,
239
+ embedding_lr_mul=embedding_lr_mul,
240
+ normalize_embedding=normalize_embedding,
241
+ normalize_embedding_latent=normalize_embedding_latent,
242
+ eps=eps)
243
+
244
+ # This is used for truncation trick.
245
+ if self.repeat_w:
246
+ self.register_buffer('w_avg', torch.zeros(w_dim))
247
+ else:
248
+ self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim))
249
+
250
+ self.synthesis = SynthesisNetwork(resolution=resolution,
251
+ init_res=init_res,
252
+ w_dim=w_dim,
253
+ image_channels=image_channels,
254
+ final_tanh=final_tanh,
255
+ const_input=const_input,
256
+ architecture=architecture,
257
+ demodulate=demodulate,
258
+ use_wscale=use_wscale,
259
+ wscale_gain=wscale_gain,
260
+ lr_mul=lr_mul,
261
+ noise_type=noise_type,
262
+ fmaps_base=fmaps_base,
263
+ filter_kernel=filter_kernel,
264
+ fmaps_max=fmaps_max,
265
+ conv_clamp=conv_clamp,
266
+ eps=eps)
267
+
268
+ self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
269
+ for key, val in self.mapping.pth_to_tf_var_mapping.items():
270
+ self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
271
+ for key, val in self.synthesis.pth_to_tf_var_mapping.items():
272
+ self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val
273
+
274
+ def set_space_of_latent(self, space_of_latent):
275
+ """Sets the space to which the latent code belong.
276
+
277
+ See `SynthesisNetwork` for more details.
278
+ """
279
+ self.synthesis.set_space_of_latent(space_of_latent)
280
+
281
+ def forward(self,
282
+ z,
283
+ label=None,
284
+ w_moving_decay=None,
285
+ sync_w_avg=False,
286
+ style_mixing_prob=None,
287
+ trunc_psi=None,
288
+ trunc_layers=None,
289
+ noise_mode='const',
290
+ fused_modulate=False,
291
+ fp16_res=None,
292
+ impl='cuda'):
293
+ """Connects mapping network and synthesis network.
294
+
295
+ This forward function will also update the average `w_code`, perform
296
+ style mixing as a training regularizer, and do truncation trick, which
297
+ is specially designed for inference.
298
+
299
+ Concretely, the truncation trick acts as follows:
300
+
301
+ For layers in range [0, truncation_layers), the truncated w-code is
302
+ computed as
303
+
304
+ w_new = w_avg + (w - w_avg) * truncation_psi
305
+
306
+ To disable truncation, please set
307
+
308
+ (1) truncation_psi = 1.0 (None) OR
309
+ (2) truncation_layers = 0 (None)
310
+ """
311
+
312
+ mapping_results = self.mapping(z, label, impl=impl)
313
+
314
+ w = mapping_results['w']
315
+ if self.training and w_moving_decay is not None:
316
+ if sync_w_avg:
317
+ batch_w_avg = all_gather(w.detach()).mean(dim=0)
318
+ else:
319
+ batch_w_avg = w.detach().mean(dim=0)
320
+ self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay))
321
+
322
+ wp = mapping_results.pop('wp')
323
+ if self.training and style_mixing_prob is not None:
324
+ if np.random.uniform() < style_mixing_prob:
325
+ new_z = torch.randn_like(z)
326
+ new_wp = self.mapping(new_z, label, impl=impl)['wp']
327
+ mixing_cutoff = np.random.randint(1, self.num_layers)
328
+ wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:]
329
+
330
+ if not self.training:
331
+ trunc_psi = 1.0 if trunc_psi is None else trunc_psi
332
+ trunc_layers = 0 if trunc_layers is None else trunc_layers
333
+ if trunc_psi < 1.0 and trunc_layers > 0:
334
+ w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers]
335
+ wp[:, :trunc_layers] = w_avg.lerp(
336
+ wp[:, :trunc_layers], trunc_psi)
337
+
338
+ synthesis_results = self.synthesis(wp,
339
+ noise_mode=noise_mode,
340
+ fused_modulate=fused_modulate,
341
+ impl=impl,
342
+ fp16_res=fp16_res)
343
+
344
+ return {**mapping_results, **synthesis_results}
345
+
346
+
347
+ class MappingNetwork(nn.Module):
348
+ """Implements the latent space mapping network.
349
+
350
+ Basically, this network executes several dense layers in sequence, and the
351
+ label embedding if needed.
352
+ """
353
+
354
+ def __init__(self,
355
+ input_dim,
356
+ output_dim,
357
+ num_outputs,
358
+ repeat_output,
359
+ normalize_input,
360
+ num_layers,
361
+ hidden_dim,
362
+ use_wscale,
363
+ wscale_gain,
364
+ lr_mul,
365
+ label_dim,
366
+ embedding_dim,
367
+ embedding_bias,
368
+ embedding_use_wscale,
369
+ embedding_wscale_gian,
370
+ embedding_lr_mul,
371
+ normalize_embedding,
372
+ normalize_embedding_latent,
373
+ eps):
374
+ super().__init__()
375
+
376
+ self.input_dim = input_dim
377
+ self.output_dim = output_dim
378
+ self.num_outputs = num_outputs
379
+ self.repeat_output = repeat_output
380
+ self.normalize_input = normalize_input
381
+ self.num_layers = num_layers
382
+ self.hidden_dim = hidden_dim
383
+ self.use_wscale = use_wscale
384
+ self.wscale_gain = wscale_gain
385
+ self.lr_mul = lr_mul
386
+ self.label_dim = label_dim
387
+ self.embedding_dim = embedding_dim
388
+ self.embedding_bias = embedding_bias
389
+ self.embedding_use_wscale = embedding_use_wscale
390
+ self.embedding_wscale_gian = embedding_wscale_gian
391
+ self.embedding_lr_mul = embedding_lr_mul
392
+ self.normalize_embedding = normalize_embedding
393
+ self.normalize_embedding_latent = normalize_embedding_latent
394
+ self.eps = eps
395
+
396
+ self.pth_to_tf_var_mapping = {}
397
+
398
+ self.norm = PixelNormLayer(dim=1, eps=eps)
399
+
400
+ if self.label_dim > 0:
401
+ input_dim = input_dim + embedding_dim
402
+ self.embedding = DenseLayer(in_channels=label_dim,
403
+ out_channels=embedding_dim,
404
+ add_bias=embedding_bias,
405
+ init_bias=0.0,
406
+ use_wscale=embedding_use_wscale,
407
+ wscale_gain=embedding_wscale_gian,
408
+ lr_mul=embedding_lr_mul,
409
+ activation_type='linear')
410
+ self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight'
411
+ if self.embedding_bias:
412
+ self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias'
413
+
414
+ if num_outputs is not None and not repeat_output:
415
+ output_dim = output_dim * num_outputs
416
+ for i in range(num_layers):
417
+ in_channels = (input_dim if i == 0 else hidden_dim)
418
+ out_channels = (output_dim if i == (num_layers - 1) else hidden_dim)
419
+ self.add_module(f'dense{i}',
420
+ DenseLayer(in_channels=in_channels,
421
+ out_channels=out_channels,
422
+ add_bias=True,
423
+ init_bias=0.0,
424
+ use_wscale=use_wscale,
425
+ wscale_gain=wscale_gain,
426
+ lr_mul=lr_mul,
427
+ activation_type='lrelu'))
428
+ self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
429
+ self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
430
+
431
+ def forward(self, z, label=None, impl='cuda'):
432
+ if z.ndim != 2 or z.shape[1] != self.input_dim:
433
+ raise ValueError(f'Input latent code should be with shape '
434
+ f'[batch_size, input_dim], where '
435
+ f'`input_dim` equals to {self.input_dim}!\n'
436
+ f'But `{z.shape}` is received!')
437
+ if self.normalize_input:
438
+ z = self.norm(z)
439
+
440
+ if self.label_dim > 0:
441
+ if label is None:
442
+ raise ValueError(f'Model requires an additional label '
443
+ f'(with dimension {self.label_dim}) as input, '
444
+ f'but no label is received!')
445
+ if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim):
446
+ raise ValueError(f'Input label should be with shape '
447
+ f'[batch_size, label_dim], where '
448
+ f'`batch_size` equals to that of '
449
+ f'latent codes ({z.shape[0]}) and '
450
+ f'`label_dim` equals to {self.label_dim}!\n'
451
+ f'But `{label.shape}` is received!')
452
+ label = label.to(dtype=torch.float32)
453
+ embedding = self.embedding(label, impl=impl)
454
+ if self.normalize_embedding:
455
+ embedding = self.norm(embedding)
456
+ w = torch.cat((z, embedding), dim=1)
457
+ else:
458
+ w = z
459
+
460
+ if self.label_dim > 0 and self.normalize_embedding_latent:
461
+ w = self.norm(w)
462
+
463
+ for i in range(self.num_layers):
464
+ w = getattr(self, f'dense{i}')(w, impl=impl)
465
+
466
+ wp = None
467
+ if self.num_outputs is not None:
468
+ if self.repeat_output:
469
+ wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1))
470
+ else:
471
+ wp = w.reshape(-1, self.num_outputs, self.output_dim)
472
+
473
+ results = {
474
+ 'z': z,
475
+ 'label': label,
476
+ 'w': w,
477
+ 'wp': wp,
478
+ }
479
+ if self.label_dim > 0:
480
+ results['embedding'] = embedding
481
+ return results
482
+
483
+
484
+ class SynthesisNetwork(nn.Module):
485
+ """Implements the image synthesis network.
486
+
487
+ Basically, this network executes several convolutional layers in sequence.
488
+ """
489
+
490
+ def __init__(self,
491
+ resolution,
492
+ init_res,
493
+ w_dim,
494
+ image_channels,
495
+ final_tanh,
496
+ const_input,
497
+ architecture,
498
+ demodulate,
499
+ use_wscale,
500
+ wscale_gain,
501
+ lr_mul,
502
+ noise_type,
503
+ fmaps_base,
504
+ fmaps_max,
505
+ filter_kernel,
506
+ conv_clamp,
507
+ eps):
508
+ super().__init__()
509
+
510
+ self.init_res = init_res
511
+ self.init_res_log2 = int(np.log2(init_res))
512
+ self.resolution = resolution
513
+ self.final_res_log2 = int(np.log2(resolution))
514
+ self.w_dim = w_dim
515
+ self.image_channels = image_channels
516
+ self.final_tanh = final_tanh
517
+ self.const_input = const_input
518
+ self.architecture = architecture.lower()
519
+ self.demodulate = demodulate
520
+ self.use_wscale = use_wscale
521
+ self.wscale_gain = wscale_gain
522
+ self.lr_mul = lr_mul
523
+ self.noise_type = noise_type.lower()
524
+ self.fmaps_base = fmaps_base
525
+ self.fmaps_max = fmaps_max
526
+ self.filter_kernel = filter_kernel
527
+ self.conv_clamp = conv_clamp
528
+ self.eps = eps
529
+
530
+ self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
531
+
532
+ self.pth_to_tf_var_mapping = {}
533
+
534
+ for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
535
+ res = 2 ** res_log2
536
+ in_channels = self.get_nf(res // 2)
537
+ out_channels = self.get_nf(res)
538
+ block_idx = res_log2 - self.init_res_log2
539
+
540
+ # Early layer.
541
+ if res == init_res:
542
+ if self.const_input:
543
+ self.add_module('early_layer',
544
+ InputLayer(init_res=res,
545
+ channels=out_channels))
546
+ self.pth_to_tf_var_mapping['early_layer.const'] = (
547
+ f'{res}x{res}/Const/const')
548
+ else:
549
+ channels = out_channels * res * res
550
+ self.add_module('early_layer',
551
+ DenseLayer(in_channels=w_dim,
552
+ out_channels=channels,
553
+ add_bias=True,
554
+ init_bias=0.0,
555
+ use_wscale=use_wscale,
556
+ wscale_gain=wscale_gain,
557
+ lr_mul=lr_mul,
558
+ activation_type='lrelu'))
559
+ self.pth_to_tf_var_mapping['early_layer.weight'] = (
560
+ f'{res}x{res}/Dense/weight')
561
+ self.pth_to_tf_var_mapping['early_layer.bias'] = (
562
+ f'{res}x{res}/Dense/bias')
563
+ else:
564
+ # Residual branch (kernel 1x1) with upsampling, without bias,
565
+ # with linear activation.
566
+ if self.architecture == 'resnet':
567
+ layer_name = f'residual{block_idx}'
568
+ self.add_module(layer_name,
569
+ ConvLayer(in_channels=in_channels,
570
+ out_channels=out_channels,
571
+ kernel_size=1,
572
+ add_bias=False,
573
+ scale_factor=2,
574
+ filter_kernel=filter_kernel,
575
+ use_wscale=use_wscale,
576
+ wscale_gain=wscale_gain,
577
+ lr_mul=lr_mul,
578
+ activation_type='linear',
579
+ conv_clamp=None))
580
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
581
+ f'{res}x{res}/Skip/weight')
582
+
583
+ # First layer (kernel 3x3) with upsampling.
584
+ layer_name = f'layer{2 * block_idx - 1}'
585
+ self.add_module(layer_name,
586
+ ModulateConvLayer(in_channels=in_channels,
587
+ out_channels=out_channels,
588
+ resolution=res,
589
+ w_dim=w_dim,
590
+ kernel_size=3,
591
+ add_bias=True,
592
+ scale_factor=2,
593
+ filter_kernel=filter_kernel,
594
+ demodulate=demodulate,
595
+ use_wscale=use_wscale,
596
+ wscale_gain=wscale_gain,
597
+ lr_mul=lr_mul,
598
+ noise_type=noise_type,
599
+ activation_type='lrelu',
600
+ conv_clamp=conv_clamp,
601
+ eps=eps))
602
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
603
+ f'{res}x{res}/Conv0_up/weight')
604
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
605
+ f'{res}x{res}/Conv0_up/bias')
606
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
607
+ f'{res}x{res}/Conv0_up/mod_weight')
608
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
609
+ f'{res}x{res}/Conv0_up/mod_bias')
610
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
611
+ f'{res}x{res}/Conv0_up/noise_strength')
612
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
613
+ f'noise{2 * block_idx - 1}')
614
+
615
+ # Second layer (kernel 3x3) without upsampling.
616
+ layer_name = f'layer{2 * block_idx}'
617
+ self.add_module(layer_name,
618
+ ModulateConvLayer(in_channels=out_channels,
619
+ out_channels=out_channels,
620
+ resolution=res,
621
+ w_dim=w_dim,
622
+ kernel_size=3,
623
+ add_bias=True,
624
+ scale_factor=1,
625
+ filter_kernel=None,
626
+ demodulate=demodulate,
627
+ use_wscale=use_wscale,
628
+ wscale_gain=wscale_gain,
629
+ lr_mul=lr_mul,
630
+ noise_type=noise_type,
631
+ activation_type='lrelu',
632
+ conv_clamp=conv_clamp,
633
+ eps=eps))
634
+ tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
635
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
636
+ f'{res}x{res}/{tf_layer_name}/weight')
637
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
638
+ f'{res}x{res}/{tf_layer_name}/bias')
639
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
640
+ f'{res}x{res}/{tf_layer_name}/mod_weight')
641
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
642
+ f'{res}x{res}/{tf_layer_name}/mod_bias')
643
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
644
+ f'{res}x{res}/{tf_layer_name}/noise_strength')
645
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
646
+ f'noise{2 * block_idx}')
647
+
648
+ # Output convolution layer for each resolution (if needed).
649
+ if res_log2 == self.final_res_log2 or self.architecture == 'skip':
650
+ layer_name = f'output{block_idx}'
651
+ self.add_module(layer_name,
652
+ ModulateConvLayer(in_channels=out_channels,
653
+ out_channels=image_channels,
654
+ resolution=res,
655
+ w_dim=w_dim,
656
+ kernel_size=1,
657
+ add_bias=True,
658
+ scale_factor=1,
659
+ filter_kernel=None,
660
+ demodulate=False,
661
+ use_wscale=use_wscale,
662
+ wscale_gain=wscale_gain,
663
+ lr_mul=lr_mul,
664
+ noise_type='none',
665
+ activation_type='linear',
666
+ conv_clamp=conv_clamp,
667
+ eps=eps))
668
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
669
+ f'{res}x{res}/ToRGB/weight')
670
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
671
+ f'{res}x{res}/ToRGB/bias')
672
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
673
+ f'{res}x{res}/ToRGB/mod_weight')
674
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
675
+ f'{res}x{res}/ToRGB/mod_bias')
676
+
677
+ # Used for upsampling output images for each resolution block for sum.
678
+ if self.architecture == 'skip':
679
+ self.register_buffer(
680
+ 'filter', upfirdn2d.setup_filter(filter_kernel))
681
+
682
+ def get_nf(self, res):
683
+ """Gets number of feature maps according to the given resolution."""
684
+ return min(self.fmaps_base // res, self.fmaps_max)
685
+
686
+ def set_space_of_latent(self, space_of_latent):
687
+ """Sets the space to which the latent code belong.
688
+
689
+ This function is particularly used for choosing how to inject the latent
690
+ code into the convolutional layers. The original generator will take a
691
+ W-Space code and apply it for style modulation after an affine
692
+ transformation. But, sometimes, it may need to directly feed an already
693
+ affine-transformed code into the convolutional layer, e.g., when
694
+ training an encoder for GAN inversion. We term the transformed space as
695
+ Style Space (or Y-Space). This function is designed to tell the
696
+ convolutional layers how to use the input code.
697
+
698
+ Args:
699
+ space_of_latent: The space to which the latent code belong. Case
700
+ insensitive. Support `W` and `Y`.
701
+ """
702
+ space_of_latent = space_of_latent.upper()
703
+ for module in self.modules():
704
+ if isinstance(module, ModulateConvLayer):
705
+ setattr(module, 'space_of_latent', space_of_latent)
706
+
707
+ def forward(self,
708
+ wp,
709
+ noise_mode='const',
710
+ fused_modulate=False,
711
+ fp16_res=None,
712
+ impl='cuda'):
713
+ results = {'wp': wp}
714
+
715
+ if self.const_input:
716
+ x = self.early_layer(wp[:, 0])
717
+ else:
718
+ x = self.early_layer(wp[:, 0], impl=impl)
719
+
720
+ # Cast to `torch.float16` if needed.
721
+ if fp16_res is not None and self.init_res >= fp16_res:
722
+ x = x.to(torch.float16)
723
+
724
+ if self.architecture == 'origin':
725
+ for layer_idx in range(self.num_layers - 1):
726
+ layer = getattr(self, f'layer{layer_idx}')
727
+ x, style = layer(x,
728
+ wp[:, layer_idx],
729
+ noise_mode=noise_mode,
730
+ fused_modulate=fused_modulate,
731
+ impl=impl)
732
+ results[f'style{layer_idx}'] = style
733
+
734
+ # Cast to `torch.float16` if needed.
735
+ if layer_idx % 2 == 0 and layer_idx != self.num_layers - 2:
736
+ res = self.init_res * (2 ** (layer_idx // 2))
737
+ if fp16_res is not None and res * 2 >= fp16_res:
738
+ x = x.to(torch.float16)
739
+ else:
740
+ x = x.to(torch.float32)
741
+ output_layer = getattr(self, f'output{layer_idx // 2}')
742
+ image, style = output_layer(x,
743
+ wp[:, layer_idx + 1],
744
+ fused_modulate=fused_modulate,
745
+ impl=impl)
746
+ image = image.to(torch.float32)
747
+ results[f'output_style{layer_idx // 2}'] = style
748
+
749
+ elif self.architecture == 'skip':
750
+ for layer_idx in range(self.num_layers - 1):
751
+ layer = getattr(self, f'layer{layer_idx}')
752
+ x, style = layer(x,
753
+ wp[:, layer_idx],
754
+ noise_mode=noise_mode,
755
+ fused_modulate=fused_modulate,
756
+ impl=impl)
757
+ results[f'style{layer_idx}'] = style
758
+ if layer_idx % 2 == 0:
759
+ output_layer = getattr(self, f'output{layer_idx // 2}')
760
+ y, style = output_layer(x,
761
+ wp[:, layer_idx + 1],
762
+ fused_modulate=fused_modulate,
763
+ impl=impl)
764
+ results[f'output_style{layer_idx // 2}'] = style
765
+ if layer_idx == 0:
766
+ image = y.to(torch.float32)
767
+ else:
768
+ image = y.to(torch.float32) + upfirdn2d.upsample2d(
769
+ image, self.filter, impl=impl)
770
+
771
+ # Cast to `torch.float16` if needed.
772
+ if layer_idx != self.num_layers - 2:
773
+ res = self.init_res * (2 ** (layer_idx // 2))
774
+ if fp16_res is not None and res * 2 >= fp16_res:
775
+ x = x.to(torch.float16)
776
+ else:
777
+ x = x.to(torch.float32)
778
+
779
+ elif self.architecture == 'resnet':
780
+ x, style = self.layer0(x,
781
+ wp[:, 0],
782
+ noise_mode=noise_mode,
783
+ fused_modulate=fused_modulate,
784
+ impl=impl)
785
+ results['style0'] = style
786
+ for layer_idx in range(1, self.num_layers - 1, 2):
787
+ # Cast to `torch.float16` if needed.
788
+ if layer_idx % 2 == 1:
789
+ res = self.init_res * (2 ** (layer_idx // 2))
790
+ if fp16_res is not None and res * 2 >= fp16_res:
791
+ x = x.to(torch.float16)
792
+ else:
793
+ x = x.to(torch.float32)
794
+
795
+ skip_layer = getattr(self, f'residual{layer_idx // 2 + 1}')
796
+ residual = skip_layer(x, runtime_gain=np.sqrt(0.5), impl=impl)
797
+ layer = getattr(self, f'layer{layer_idx}')
798
+ x, style = layer(x,
799
+ wp[:, layer_idx],
800
+ noise_mode=noise_mode,
801
+ fused_modulate=fused_modulate,
802
+ impl=impl)
803
+ results[f'style{layer_idx}'] = style
804
+ layer = getattr(self, f'layer{layer_idx + 1}')
805
+ x, style = layer(x,
806
+ wp[:, layer_idx + 1],
807
+ runtime_gain=np.sqrt(0.5),
808
+ noise_mode=noise_mode,
809
+ fused_modulate=fused_modulate,
810
+ impl=impl)
811
+ results[f'style{layer_idx + 1}'] = style
812
+ x = x + residual
813
+ output_layer = getattr(self, f'output{layer_idx // 2 + 1}')
814
+ image, style = output_layer(x,
815
+ wp[:, layer_idx + 2],
816
+ fused_modulate=fused_modulate,
817
+ impl=impl)
818
+ image = image.to(torch.float32)
819
+ results[f'output_style{layer_idx // 2}'] = style
820
+
821
+ if self.final_tanh:
822
+ image = torch.tanh(image)
823
+ results['image'] = image
824
+ return results
825
+
826
+
827
+ class PixelNormLayer(nn.Module):
828
+ """Implements pixel-wise feature vector normalization layer."""
829
+
830
+ def __init__(self, dim, eps):
831
+ super().__init__()
832
+ self.dim = dim
833
+ self.eps = eps
834
+
835
+ def extra_repr(self):
836
+ return f'dim={self.dim}, epsilon={self.eps}'
837
+
838
+ def forward(self, x):
839
+ scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
840
+ return x * scale
841
+
842
+
843
+ class InputLayer(nn.Module):
844
+ """Implements the input layer to start convolution with.
845
+
846
+ Basically, this block starts from a const input, which is with shape
847
+ `(channels, init_res, init_res)`.
848
+ """
849
+
850
+ def __init__(self, init_res, channels):
851
+ super().__init__()
852
+ self.const = nn.Parameter(torch.randn(1, channels, init_res, init_res))
853
+
854
+ def forward(self, w):
855
+ x = self.const.repeat(w.shape[0], 1, 1, 1)
856
+ return x
857
+
858
+
859
+ class ConvLayer(nn.Module):
860
+ """Implements the convolutional layer.
861
+
862
+ If upsampling is needed (i.e., `scale_factor = 2`), the feature map will
863
+ be filtered with `filter_kernel` after convolution. This layer will only be
864
+ used for skip connection in `resnet` architecture.
865
+ """
866
+
867
+ def __init__(self,
868
+ in_channels,
869
+ out_channels,
870
+ kernel_size,
871
+ add_bias,
872
+ scale_factor,
873
+ filter_kernel,
874
+ use_wscale,
875
+ wscale_gain,
876
+ lr_mul,
877
+ activation_type,
878
+ conv_clamp):
879
+ """Initializes with layer settings.
880
+
881
+ Args:
882
+ in_channels: Number of channels of the input tensor.
883
+ out_channels: Number of channels of the output tensor.
884
+ kernel_size: Size of the convolutional kernels.
885
+ add_bias: Whether to add bias onto the convolutional result.
886
+ scale_factor: Scale factor for upsampling.
887
+ filter_kernel: Kernel used for filtering.
888
+ use_wscale: Whether to use weight scaling.
889
+ wscale_gain: Gain factor for weight scaling.
890
+ lr_mul: Learning multiplier for both weight and bias.
891
+ activation_type: Type of activation.
892
+ conv_clamp: A threshold to clamp the output of convolution layers to
893
+ avoid overflow under FP16 training.
894
+ """
895
+ super().__init__()
896
+ self.in_channels = in_channels
897
+ self.out_channels = out_channels
898
+ self.kernel_size = kernel_size
899
+ self.add_bias = add_bias
900
+ self.scale_factor = scale_factor
901
+ self.filter_kernel = filter_kernel
902
+ self.use_wscale = use_wscale
903
+ self.wscale_gain = wscale_gain
904
+ self.lr_mul = lr_mul
905
+ self.activation_type = activation_type
906
+ self.conv_clamp = conv_clamp
907
+
908
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
909
+ fan_in = kernel_size * kernel_size * in_channels
910
+ wscale = wscale_gain / np.sqrt(fan_in)
911
+ if use_wscale:
912
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
913
+ self.wscale = wscale * lr_mul
914
+ else:
915
+ self.weight = nn.Parameter(
916
+ torch.randn(*weight_shape) * wscale / lr_mul)
917
+ self.wscale = lr_mul
918
+
919
+ if add_bias:
920
+ self.bias = nn.Parameter(torch.zeros(out_channels))
921
+ self.bscale = lr_mul
922
+ else:
923
+ self.bias = None
924
+ self.act_gain = bias_act.activation_funcs[activation_type].def_gain
925
+
926
+ if scale_factor > 1:
927
+ assert filter_kernel is not None
928
+ self.register_buffer(
929
+ 'filter', upfirdn2d.setup_filter(filter_kernel))
930
+ fh, fw = self.filter.shape
931
+ self.filter_padding = (
932
+ kernel_size // 2 + (fw + scale_factor - 1) // 2,
933
+ kernel_size // 2 + (fw - scale_factor) // 2,
934
+ kernel_size // 2 + (fh + scale_factor - 1) // 2,
935
+ kernel_size // 2 + (fh - scale_factor) // 2)
936
+
937
+ def extra_repr(self):
938
+ return (f'in_ch={self.in_channels}, '
939
+ f'out_ch={self.out_channels}, '
940
+ f'ksize={self.kernel_size}, '
941
+ f'wscale_gain={self.wscale_gain:.3f}, '
942
+ f'bias={self.add_bias}, '
943
+ f'lr_mul={self.lr_mul:.3f}, '
944
+ f'upsample={self.scale_factor}, '
945
+ f'upsample_filter={self.filter_kernel}, '
946
+ f'act={self.activation_type}, '
947
+ f'clamp={self.conv_clamp}')
948
+
949
+ def forward(self, x, runtime_gain=1.0, impl='cuda'):
950
+ dtype = x.dtype
951
+
952
+ weight = self.weight
953
+ if self.wscale != 1.0:
954
+ weight = weight * self.wscale
955
+ bias = None
956
+ if self.bias is not None:
957
+ bias = self.bias.to(dtype)
958
+ if self.bscale != 1.0:
959
+ bias = bias * self.bscale
960
+
961
+ if self.scale_factor == 1: # Native convolution without upsampling.
962
+ padding = self.kernel_size // 2
963
+ x = conv2d_gradfix.conv2d(
964
+ x, weight.to(dtype), stride=1, padding=padding, impl=impl)
965
+ else: # Convolution with upsampling.
966
+ up = self.scale_factor
967
+ f = self.filter
968
+ # When kernel size = 1, use filtering function for upsampling.
969
+ if self.kernel_size == 1:
970
+ padding = self.filter_padding
971
+ x = conv2d_gradfix.conv2d(
972
+ x, weight.to(dtype), stride=1, padding=0, impl=impl)
973
+ x = upfirdn2d.upfirdn2d(
974
+ x, f, up=up, padding=padding, gain=up ** 2, impl=impl)
975
+ # When kernel size != 1, use transpose convolution for upsampling.
976
+ else:
977
+ # Following codes are borrowed from
978
+ # https://github.com/NVlabs/stylegan2-ada-pytorch
979
+ px0, px1, py0, py1 = self.filter_padding
980
+ kh, kw = weight.shape[2:]
981
+ px0 = px0 - (kw - 1)
982
+ px1 = px1 - (kw - up)
983
+ py0 = py0 - (kh - 1)
984
+ py1 = py1 - (kh - up)
985
+ pxt = max(min(-px0, -px1), 0)
986
+ pyt = max(min(-py0, -py1), 0)
987
+ weight = weight.transpose(0, 1)
988
+ padding = (pyt, pxt)
989
+ x = conv2d_gradfix.conv_transpose2d(
990
+ x, weight.to(dtype), stride=up, padding=padding, impl=impl)
991
+ padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt)
992
+ x = upfirdn2d.upfirdn2d(
993
+ x, f, up=1, padding=padding, gain=up ** 2, impl=impl)
994
+
995
+ act_gain = self.act_gain * runtime_gain
996
+ act_clamp = None
997
+ if self.conv_clamp is not None:
998
+ act_clamp = self.conv_clamp * runtime_gain
999
+ x = bias_act.bias_act(x, bias,
1000
+ act=self.activation_type,
1001
+ gain=act_gain,
1002
+ clamp=act_clamp,
1003
+ impl=impl)
1004
+
1005
+ assert x.dtype == dtype
1006
+ return x
1007
+
1008
+
1009
+ class ModulateConvLayer(nn.Module):
1010
+ """Implements the convolutional layer with style modulation."""
1011
+
1012
+ def __init__(self,
1013
+ in_channels,
1014
+ out_channels,
1015
+ resolution,
1016
+ w_dim,
1017
+ kernel_size,
1018
+ add_bias,
1019
+ scale_factor,
1020
+ filter_kernel,
1021
+ demodulate,
1022
+ use_wscale,
1023
+ wscale_gain,
1024
+ lr_mul,
1025
+ noise_type,
1026
+ activation_type,
1027
+ conv_clamp,
1028
+ eps):
1029
+ """Initializes with layer settings.
1030
+
1031
+ Args:
1032
+ in_channels: Number of channels of the input tensor.
1033
+ out_channels: Number of channels of the output tensor.
1034
+ resolution: Resolution of the output tensor.
1035
+ w_dim: Dimension of W space for style modulation.
1036
+ kernel_size: Size of the convolutional kernels.
1037
+ add_bias: Whether to add bias onto the convolutional result.
1038
+ scale_factor: Scale factor for upsampling.
1039
+ filter_kernel: Kernel used for filtering.
1040
+ demodulate: Whether to perform style demodulation.
1041
+ use_wscale: Whether to use weight scaling.
1042
+ wscale_gain: Gain factor for weight scaling.
1043
+ lr_mul: Learning multiplier for both weight and bias.
1044
+ noise_type: Type of noise added to the feature map after the
1045
+ convolution (if needed). Support `none`, `spatial` and
1046
+ `channel`.
1047
+ activation_type: Type of activation.
1048
+ conv_clamp: A threshold to clamp the output of convolution layers to
1049
+ avoid overflow under FP16 training.
1050
+ eps: A small value to avoid divide overflow.
1051
+ """
1052
+ super().__init__()
1053
+
1054
+ self.in_channels = in_channels
1055
+ self.out_channels = out_channels
1056
+ self.resolution = resolution
1057
+ self.w_dim = w_dim
1058
+ self.kernel_size = kernel_size
1059
+ self.add_bias = add_bias
1060
+ self.scale_factor = scale_factor
1061
+ self.filter_kernel = filter_kernel
1062
+ self.demodulate = demodulate
1063
+ self.use_wscale = use_wscale
1064
+ self.wscale_gain = wscale_gain
1065
+ self.lr_mul = lr_mul
1066
+ self.noise_type = noise_type.lower()
1067
+ self.activation_type = activation_type
1068
+ self.conv_clamp = conv_clamp
1069
+ self.eps = eps
1070
+
1071
+ self.space_of_latent = 'W'
1072
+
1073
+ # Set up weight.
1074
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
1075
+ fan_in = kernel_size * kernel_size * in_channels
1076
+ wscale = wscale_gain / np.sqrt(fan_in)
1077
+ if use_wscale:
1078
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
1079
+ self.wscale = wscale * lr_mul
1080
+ else:
1081
+ self.weight = nn.Parameter(
1082
+ torch.randn(*weight_shape) * wscale / lr_mul)
1083
+ self.wscale = lr_mul
1084
+
1085
+ # Set up bias.
1086
+ if add_bias:
1087
+ self.bias = nn.Parameter(torch.zeros(out_channels))
1088
+ self.bscale = lr_mul
1089
+ else:
1090
+ self.bias = None
1091
+ self.act_gain = bias_act.activation_funcs[activation_type].def_gain
1092
+
1093
+ # Set up style.
1094
+ self.style = DenseLayer(in_channels=w_dim,
1095
+ out_channels=in_channels,
1096
+ add_bias=True,
1097
+ init_bias=1.0,
1098
+ use_wscale=use_wscale,
1099
+ wscale_gain=wscale_gain,
1100
+ lr_mul=lr_mul,
1101
+ activation_type='linear')
1102
+
1103
+ # Set up noise.
1104
+ if self.noise_type != 'none':
1105
+ self.noise_strength = nn.Parameter(torch.zeros(()))
1106
+ if self.noise_type == 'spatial':
1107
+ self.register_buffer(
1108
+ 'noise', torch.randn(1, 1, resolution, resolution))
1109
+ elif self.noise_type == 'channel':
1110
+ self.register_buffer(
1111
+ 'noise', torch.randn(1, out_channels, 1, 1))
1112
+ else:
1113
+ raise NotImplementedError(f'Not implemented noise type: '
1114
+ f'`{self.noise_type}`!')
1115
+
1116
+ if scale_factor > 1:
1117
+ assert filter_kernel is not None
1118
+ self.register_buffer(
1119
+ 'filter', upfirdn2d.setup_filter(filter_kernel))
1120
+ fh, fw = self.filter.shape
1121
+ self.filter_padding = (
1122
+ kernel_size // 2 + (fw + scale_factor - 1) // 2,
1123
+ kernel_size // 2 + (fw - scale_factor) // 2,
1124
+ kernel_size // 2 + (fh + scale_factor - 1) // 2,
1125
+ kernel_size // 2 + (fh - scale_factor) // 2)
1126
+
1127
+ def extra_repr(self):
1128
+ return (f'in_ch={self.in_channels}, '
1129
+ f'out_ch={self.out_channels}, '
1130
+ f'ksize={self.kernel_size}, '
1131
+ f'wscale_gain={self.wscale_gain:.3f}, '
1132
+ f'bias={self.add_bias}, '
1133
+ f'lr_mul={self.lr_mul:.3f}, '
1134
+ f'upsample={self.scale_factor}, '
1135
+ f'upsample_filter={self.filter_kernel}, '
1136
+ f'demodulate={self.demodulate}, '
1137
+ f'noise_type={self.noise_type}, '
1138
+ f'act={self.activation_type}, '
1139
+ f'clamp={self.conv_clamp}')
1140
+
1141
+ def forward_style(self, w, impl='cuda'):
1142
+ """Gets style code from the given input.
1143
+
1144
+ More specifically, if the input is from W-Space, it will be projected by
1145
+ an affine transformation. If it is from the Style Space (Y-Space), no
1146
+ operation is required.
1147
+
1148
+ NOTE: For codes from Y-Space, we use slicing to make sure the dimension
1149
+ is correct, in case that the code is padded before fed into this layer.
1150
+ """
1151
+ space_of_latent = self.space_of_latent.upper()
1152
+ if space_of_latent == 'W':
1153
+ if w.ndim != 2 or w.shape[1] != self.w_dim:
1154
+ raise ValueError(f'The input tensor should be with shape '
1155
+ f'[batch_size, w_dim], where '
1156
+ f'`w_dim` equals to {self.w_dim}!\n'
1157
+ f'But `{w.shape}` is received!')
1158
+ style = self.style(w, impl=impl)
1159
+ elif space_of_latent == 'Y':
1160
+ if w.ndim != 2 or w.shape[1] < self.in_channels:
1161
+ raise ValueError(f'The input tensor should be with shape '
1162
+ f'[batch_size, y_dim], where '
1163
+ f'`y_dim` equals to {self.in_channels}!\n'
1164
+ f'But `{w.shape}` is received!')
1165
+ style = w[:, :self.in_channels]
1166
+ else:
1167
+ raise NotImplementedError(f'Not implemented `space_of_latent`: '
1168
+ f'`{space_of_latent}`!')
1169
+ return style
1170
+
1171
+ def forward(self,
1172
+ x,
1173
+ w,
1174
+ runtime_gain=1.0,
1175
+ noise_mode='const',
1176
+ fused_modulate=False,
1177
+ impl='cuda'):
1178
+ dtype = x.dtype
1179
+ N, C, H, W = x.shape
1180
+
1181
+ fused_modulate = (fused_modulate and
1182
+ not self.training and
1183
+ (dtype == torch.float32 or N == 1))
1184
+
1185
+ weight = self.weight
1186
+ out_ch, in_ch, kh, kw = weight.shape
1187
+ assert in_ch == C
1188
+
1189
+ # Affine on `w`.
1190
+ style = self.forward_style(w, impl=impl)
1191
+ if not self.demodulate:
1192
+ _style = style * self.wscale # Equivalent to scaling weight.
1193
+ else:
1194
+ _style = style
1195
+
1196
+ # Prepare noise.
1197
+ noise = None
1198
+ noise_mode = noise_mode.lower()
1199
+ if self.noise_type != 'none' and noise_mode != 'none':
1200
+ if noise_mode == 'random':
1201
+ noise = torch.randn((N, *self.noise.shape[1:]), device=x.device)
1202
+ elif noise_mode == 'const':
1203
+ noise = self.noise
1204
+ else:
1205
+ raise ValueError(f'Unknown noise mode `{noise_mode}`!')
1206
+ noise = (noise * self.noise_strength).to(dtype)
1207
+
1208
+ # Pre-normalize inputs to avoid FP16 overflow.
1209
+ if dtype == torch.float16 and self.demodulate:
1210
+ weight_max = weight.norm(float('inf'), dim=(1, 2, 3), keepdim=True)
1211
+ weight = weight * (self.wscale / weight_max)
1212
+ style_max = _style.norm(float('inf'), dim=1, keepdim=True)
1213
+ _style = _style / style_max
1214
+
1215
+ if self.demodulate or fused_modulate:
1216
+ _weight = weight.unsqueeze(0)
1217
+ _weight = _weight * _style.reshape(N, 1, in_ch, 1, 1)
1218
+ if self.demodulate:
1219
+ decoef = (_weight.square().sum(dim=(2, 3, 4)) + self.eps).rsqrt()
1220
+ if self.demodulate and fused_modulate:
1221
+ _weight = _weight * decoef.reshape(N, out_ch, 1, 1, 1)
1222
+
1223
+ if not fused_modulate:
1224
+ x = x * _style.to(dtype).reshape(N, in_ch, 1, 1)
1225
+ w = weight.to(dtype)
1226
+ groups = 1
1227
+ else: # Use group convolution to fuse style modulation and convolution.
1228
+ x = x.reshape(1, N * in_ch, H, W)
1229
+ w = _weight.reshape(N * out_ch, in_ch, kh, kw).to(dtype)
1230
+ groups = N
1231
+
1232
+ if self.scale_factor == 1: # Native convolution without upsampling.
1233
+ up = 1
1234
+ padding = self.kernel_size // 2
1235
+ x = conv2d_gradfix.conv2d(
1236
+ x, w, stride=1, padding=padding, groups=groups, impl=impl)
1237
+ else: # Convolution with upsampling.
1238
+ up = self.scale_factor
1239
+ f = self.filter
1240
+ # When kernel size = 1, use filtering function for upsampling.
1241
+ if self.kernel_size == 1:
1242
+ padding = self.filter_padding
1243
+ x = conv2d_gradfix.conv2d(
1244
+ x, w, stride=1, padding=0, groups=groups, impl=impl)
1245
+ x = upfirdn2d.upfirdn2d(
1246
+ x, f, up=up, padding=padding, gain=up ** 2, impl=impl)
1247
+ # When kernel size != 1, use stride convolution for upsampling.
1248
+ else:
1249
+ # Following codes are borrowed from
1250
+ # https://github.com/NVlabs/stylegan2-ada-pytorch
1251
+ px0, px1, py0, py1 = self.filter_padding
1252
+ px0 = px0 - (kw - 1)
1253
+ px1 = px1 - (kw - up)
1254
+ py0 = py0 - (kh - 1)
1255
+ py1 = py1 - (kh - up)
1256
+ pxt = max(min(-px0, -px1), 0)
1257
+ pyt = max(min(-py0, -py1), 0)
1258
+ if groups == 1:
1259
+ w = w.transpose(0, 1)
1260
+ else:
1261
+ w = w.reshape(N, out_ch, in_ch, kh, kw)
1262
+ w = w.transpose(1, 2)
1263
+ w = w.reshape(N * in_ch, out_ch, kh, kw)
1264
+ padding = (pyt, pxt)
1265
+ x = conv2d_gradfix.conv_transpose2d(
1266
+ x, w, stride=up, padding=padding, groups=groups, impl=impl)
1267
+ padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt)
1268
+ x = upfirdn2d.upfirdn2d(
1269
+ x, f, up=1, padding=padding, gain=up ** 2, impl=impl)
1270
+
1271
+ if not fused_modulate:
1272
+ if self.demodulate:
1273
+ decoef = decoef.to(dtype).reshape(N, out_ch, 1, 1)
1274
+ if self.demodulate and noise is not None:
1275
+ x = fma.fma(x, decoef, noise, impl=impl)
1276
+ else:
1277
+ if self.demodulate:
1278
+ x = x * decoef
1279
+ if noise is not None:
1280
+ x = x + noise
1281
+ else:
1282
+ x = x.reshape(N, out_ch, H * up, W * up)
1283
+ if noise is not None:
1284
+ x = x + noise
1285
+
1286
+ bias = None
1287
+ if self.bias is not None:
1288
+ bias = self.bias.to(dtype)
1289
+ if self.bscale != 1.0:
1290
+ bias = bias * self.bscale
1291
+
1292
+ if self.activation_type == 'linear': # Shortcut for output layer.
1293
+ x = bias_act.bias_act(
1294
+ x, bias, act='linear', clamp=self.conv_clamp, impl=impl)
1295
+ else:
1296
+ act_gain = self.act_gain * runtime_gain
1297
+ act_clamp = None
1298
+ if self.conv_clamp is not None:
1299
+ act_clamp = self.conv_clamp * runtime_gain
1300
+ x = bias_act.bias_act(x, bias,
1301
+ act=self.activation_type,
1302
+ gain=act_gain,
1303
+ clamp=act_clamp,
1304
+ impl=impl)
1305
+
1306
+ assert x.dtype == dtype
1307
+ assert style.dtype == torch.float32
1308
+ return x, style
1309
+
1310
+
1311
+ class DenseLayer(nn.Module):
1312
+ """Implements the dense layer."""
1313
+
1314
+ def __init__(self,
1315
+ in_channels,
1316
+ out_channels,
1317
+ add_bias,
1318
+ init_bias,
1319
+ use_wscale,
1320
+ wscale_gain,
1321
+ lr_mul,
1322
+ activation_type):
1323
+ """Initializes with layer settings.
1324
+
1325
+ Args:
1326
+ in_channels: Number of channels of the input tensor.
1327
+ out_channels: Number of channels of the output tensor.
1328
+ add_bias: Whether to add bias onto the fully-connected result.
1329
+ init_bias: The initial bias value before training.
1330
+ use_wscale: Whether to use weight scaling.
1331
+ wscale_gain: Gain factor for weight scaling.
1332
+ lr_mul: Learning multiplier for both weight and bias.
1333
+ activation_type: Type of activation.
1334
+ """
1335
+ super().__init__()
1336
+ self.in_channels = in_channels
1337
+ self.out_channels = out_channels
1338
+ self.add_bias = add_bias
1339
+ self.init_bias = init_bias
1340
+ self.use_wscale = use_wscale
1341
+ self.wscale_gain = wscale_gain
1342
+ self.lr_mul = lr_mul
1343
+ self.activation_type = activation_type
1344
+
1345
+ weight_shape = (out_channels, in_channels)
1346
+ wscale = wscale_gain / np.sqrt(in_channels)
1347
+ if use_wscale:
1348
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
1349
+ self.wscale = wscale * lr_mul
1350
+ else:
1351
+ self.weight = nn.Parameter(
1352
+ torch.randn(*weight_shape) * wscale / lr_mul)
1353
+ self.wscale = lr_mul
1354
+
1355
+ if add_bias:
1356
+ init_bias = np.float32(init_bias) / lr_mul
1357
+ self.bias = nn.Parameter(torch.full([out_channels], init_bias))
1358
+ self.bscale = lr_mul
1359
+ else:
1360
+ self.bias = None
1361
+
1362
+ def extra_repr(self):
1363
+ return (f'in_ch={self.in_channels}, '
1364
+ f'out_ch={self.out_channels}, '
1365
+ f'wscale_gain={self.wscale_gain:.3f}, '
1366
+ f'bias={self.add_bias}, '
1367
+ f'init_bias={self.init_bias}, '
1368
+ f'lr_mul={self.lr_mul:.3f}, '
1369
+ f'act={self.activation_type}')
1370
+
1371
+ def forward(self, x, impl='cuda'):
1372
+ dtype = x.dtype
1373
+
1374
+ if x.ndim != 2:
1375
+ x = x.flatten(start_dim=1)
1376
+
1377
+ weight = self.weight.to(dtype) * self.wscale
1378
+ bias = None
1379
+ if self.bias is not None:
1380
+ bias = self.bias.to(dtype)
1381
+ if self.bscale != 1.0:
1382
+ bias = bias * self.bscale
1383
+
1384
+ # Fast pass for linear activation.
1385
+ if self.activation_type == 'linear' and bias is not None:
1386
+ x = torch.addmm(bias.unsqueeze(0), x, weight.t())
1387
+ else:
1388
+ x = x.matmul(weight.t())
1389
+ x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl)
1390
+
1391
+ assert x.dtype == dtype
1392
+ return x
1393
+
1394
+ # pylint: enable=missing-function-docstring
models/stylegan3_generator.py ADDED
@@ -0,0 +1,1332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the implementation of generator described in StyleGAN3.
3
+
4
+ Compared to that of StyleGAN2, the generator in StyleGAN3 controls the frequency
5
+ flow along with the convolutional layers growing.
6
+
7
+ Paper: https://arxiv.org/pdf/2106.12423.pdf
8
+
9
+ Official implementation: https://github.com/NVlabs/stylegan3
10
+ """
11
+
12
+ import numpy as np
13
+ import scipy.signal
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ from third_party.stylegan3_official_ops import bias_act
20
+ from third_party.stylegan3_official_ops import filtered_lrelu
21
+ from third_party.stylegan3_official_ops import conv2d_gradfix
22
+ from .utils.ops import all_gather
23
+
24
+ __all__ = ['StyleGAN3Generator']
25
+
26
+ # Resolutions allowed.
27
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
28
+
29
+ # pylint: disable=missing-function-docstring
30
+
31
+ class StyleGAN3Generator(nn.Module):
32
+ """Defines the generator network in StyleGAN3.
33
+
34
+ NOTE: The synthesized images are with `RGB` channel order and pixel range
35
+ [-1, 1].
36
+
37
+ Settings for the mapping network:
38
+
39
+ (1) z_dim: Dimension of the input latent space, Z. (default: 512)
40
+ (2) w_dim: Dimension of the output latent space, W. (default: 512)
41
+ (3) repeat_w: Repeat w-code for different layers. (default: True)
42
+ (4) normalize_z: Whether to normalize the z-code. (default: True)
43
+ (5) mapping_layers: Number of layers of the mapping network. (default: 2)
44
+ (6) mapping_fmaps: Number of hidden channels of the mapping network.
45
+ (default: 512)
46
+ (7) mapping_lr_mul: Learning rate multiplier for the mapping network.
47
+ (default: 0.01)
48
+
49
+ Settings for conditional generation:
50
+
51
+ (1) label_dim: Dimension of the additional label for conditional generation.
52
+ In one-hot conditioning case, it is equal to the number of classes. If
53
+ set to 0, conditioning training will be disabled. (default: 0)
54
+ (2) embedding_dim: Dimension of the embedding space, if needed.
55
+ (default: 512)
56
+ (3) embedding_bias: Whether to add bias to embedding learning.
57
+ (default: True)
58
+ (4) embedding_lr_mul: Learning rate multiplier for the embedding learning.
59
+ (default: 1.0)
60
+ (5) normalize_embedding: Whether to normalize the embedding. (default: True)
61
+ (6) normalize_embedding_latent: Whether to normalize the embedding together
62
+ with the latent. (default: False)
63
+
64
+ Settings for the synthesis network:
65
+
66
+ (1) resolution: The resolution of the output image. (default: -1)
67
+ (2) image_channels: Number of channels of the output image. (default: 3)
68
+ (3) final_tanh: Whether to use `tanh` to control the final pixel range.
69
+ (default: False)
70
+ (4) output_scale: Factor to scaling the output image. (default: 0.25)
71
+ (5) num_layers: Number of synthesis layers, excluding the first positional
72
+ encoding layer and the last ToRGB layer. (default: 14)
73
+ (6) num_critical: Number of synthesis layers with critical sampling. These
74
+ layers are always set as top (with highest resolution) ones.
75
+ (7) fmaps_base: Factor to control number of feature maps for each layer.
76
+ (default: 32 << 10)
77
+ (8) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
78
+ (9) kernel_size: Size of convolutional kernels. (default: 1)
79
+ (10) conv_clamp: A threshold to clamp the output of convolution layers to
80
+ avoid overflow under FP16 training. (default: None)
81
+ (11) first_cutoff: Cutoff frequency of the first layer. (default: 2)
82
+ (12) first_stopband: Stopband of the first layer. (default: 2 ** 2.1)
83
+ (13) last_stopband_rel: Stopband of the last layer, relative to the last
84
+ cutoff, which is `resolution / 2`. Concretely, `last_stopband` will be
85
+ equal to `resolution / 2 * last_stopband_rel`. (default: 2 ** 0.3)
86
+ (14) margin_size: Size of margin for each feature map. (default: 10)
87
+ (15) filter_size: Size of filter for upsampling and downsampling around the
88
+ activation. (default: 6)
89
+ (16) act_upsampling: Factor used to upsample the feature map before
90
+ activation for anti-aliasing. (default: 2)
91
+ (17) use_radial_filter: Whether to use radial filter for downsampling after
92
+ the activation. (default: False)
93
+ (18) eps: A small value to avoid divide overflow. (default: 1e-8)
94
+
95
+ Runtime settings:
96
+
97
+ (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for
98
+ training only. Set `None` to disable. (default: 0.998)
99
+ (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If set
100
+ as `True`, the stats will be more accurate, yet the speed maybe a little
101
+ bit slower. (default: False)
102
+ (3) style_mixing_prob: Probability to perform style mixing as a training
103
+ regularization. Set `None` to disable. (default: None)
104
+ (4) trunc_psi: Truncation psi, set `None` to disable. (default: None)
105
+ (5) trunc_layers: Number of layers to perform truncation. (default: None)
106
+ (6) magnitude_moving_decay: Decay factor for updating `magnitude_ema` in
107
+ each `SynthesisLayer`, which is used for training only. Set `None` to
108
+ disable. (default: 0.999)
109
+ (7) update_ema: Whether to update `w_avg` in the `MappingNetwork` and
110
+ `magnitude_ema` in each `SynthesisLayer`. This field only takes effect
111
+ in `training` model. (default: False)
112
+ (8) fp16_res: Layers at resolution higher than (or equal to) this field will
113
+ use `float16` precision for computation. This is merely used for
114
+ acceleration. If set as `None`, all layers will use `float32` by
115
+ default. (default: None)
116
+ (9) impl: Implementation mode of some particular ops, e.g., `filtering`,
117
+ `bias_act`, etc. `cuda` means using the official CUDA implementation
118
+ from StyleGAN3, while `ref` means using the native PyTorch ops.
119
+ (default: `cuda`)
120
+ """
121
+
122
+ def __init__(self,
123
+ # Settings for mapping network.
124
+ z_dim=512,
125
+ w_dim=512,
126
+ repeat_w=True,
127
+ normalize_z=True,
128
+ mapping_layers=2,
129
+ mapping_fmaps=512,
130
+ mapping_lr_mul=0.01,
131
+ # Settings for conditional generation.
132
+ label_dim=0,
133
+ embedding_dim=512,
134
+ embedding_bias=True,
135
+ embedding_lr_mul=1.0,
136
+ normalize_embedding=True,
137
+ normalize_embedding_latent=False,
138
+ # Settings for synthesis network.
139
+ resolution=-1,
140
+ image_channels=3,
141
+ final_tanh=False,
142
+ output_scale=0.25,
143
+ num_layers=14,
144
+ num_critical=2,
145
+ fmaps_base=32 << 10,
146
+ fmaps_max=512,
147
+ kernel_size=1,
148
+ conv_clamp=256,
149
+ first_cutoff=2,
150
+ first_stopband=2 ** 2.1,
151
+ last_stopband_rel=2 ** 0.3,
152
+ margin_size=10,
153
+ filter_size=6,
154
+ act_upsampling=2,
155
+ use_radial_filter=False,
156
+ eps=1e-8):
157
+ """Initializes with basic settings."""
158
+ super().__init__()
159
+
160
+ if resolution not in _RESOLUTIONS_ALLOWED:
161
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
162
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
163
+
164
+ self.z_dim = z_dim
165
+ self.w_dim = w_dim
166
+ self.repeat_w = repeat_w
167
+ self.normalize_z = normalize_z
168
+ self.mapping_layers = mapping_layers
169
+ self.mapping_fmaps = mapping_fmaps
170
+ self.mapping_lr_mul = mapping_lr_mul
171
+
172
+ self.label_dim = label_dim
173
+ self.embedding_dim = embedding_dim
174
+ self.embedding_bias = embedding_bias
175
+ self.embedding_lr_mul = embedding_lr_mul
176
+ self.normalize_embedding = normalize_embedding
177
+ self.normalize_embedding_latent = normalize_embedding_latent
178
+
179
+ self.resolution = resolution
180
+ self.image_channels = image_channels
181
+ self.final_tanh = final_tanh
182
+ self.output_scale = output_scale
183
+ self.num_layers = num_layers + 2 # Including InputLayer and ToRGBLayer.
184
+ self.num_critical = num_critical
185
+ self.fmaps_base = fmaps_base
186
+ self.fmaps_max = fmaps_max
187
+ self.kernel_size = kernel_size
188
+ self.conv_clamp = conv_clamp
189
+ self.first_cutoff = first_cutoff
190
+ self.first_stopband = first_stopband
191
+ self.last_stopband_rel = last_stopband_rel
192
+ self.margin_size = margin_size
193
+ self.filter_size = filter_size
194
+ self.act_upsampling = act_upsampling
195
+ self.use_radial_filter = use_radial_filter
196
+ self.eps = eps
197
+
198
+ # Dimension of latent space, which is convenient for sampling.
199
+ self.latent_dim = (z_dim,)
200
+
201
+ self.mapping = MappingNetwork(
202
+ input_dim=z_dim,
203
+ output_dim=w_dim,
204
+ num_outputs=self.num_layers,
205
+ repeat_output=repeat_w,
206
+ normalize_input=normalize_z,
207
+ num_layers=mapping_layers,
208
+ hidden_dim=mapping_fmaps,
209
+ lr_mul=mapping_lr_mul,
210
+ label_dim=label_dim,
211
+ embedding_dim=embedding_dim,
212
+ embedding_bias=embedding_bias,
213
+ embedding_lr_mul=embedding_lr_mul,
214
+ normalize_embedding=normalize_embedding,
215
+ normalize_embedding_latent=normalize_embedding_latent,
216
+ eps=eps)
217
+
218
+ # This is used for truncation trick.
219
+ if self.repeat_w:
220
+ self.register_buffer('w_avg', torch.zeros(w_dim))
221
+ else:
222
+ self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim))
223
+
224
+ self.synthesis = SynthesisNetwork(resolution=resolution,
225
+ w_dim=w_dim,
226
+ image_channels=image_channels,
227
+ final_tanh=final_tanh,
228
+ output_scale=output_scale,
229
+ num_layers=num_layers,
230
+ num_critical=num_critical,
231
+ fmaps_base=fmaps_base,
232
+ fmaps_max=fmaps_max,
233
+ kernel_size=kernel_size,
234
+ conv_clamp=conv_clamp,
235
+ first_cutoff=first_cutoff,
236
+ first_stopband=first_stopband,
237
+ last_stopband_rel=last_stopband_rel,
238
+ margin_size=margin_size,
239
+ filter_size=filter_size,
240
+ act_upsampling=act_upsampling,
241
+ use_radial_filter=use_radial_filter,
242
+ eps=eps)
243
+
244
+ self.var_mapping = {'w_avg': 'mapping.w_avg'}
245
+ for key, val in self.mapping.var_mapping.items():
246
+ self.var_mapping[f'mapping.{key}'] = f'mapping.{val}'
247
+ for key, val in self.synthesis.var_mapping.items():
248
+ self.var_mapping[f'synthesis.{key}'] = f'synthesis.{val}'
249
+
250
+ def set_space_of_latent(self, space_of_latent):
251
+ """Sets the space to which the latent code belong.
252
+
253
+ See `SynthesisNetwork` for more details.
254
+ """
255
+ self.synthesis.set_space_of_latent(space_of_latent)
256
+
257
+ def forward(self,
258
+ z,
259
+ label=None,
260
+ w_moving_decay=0.998,
261
+ sync_w_avg=False,
262
+ style_mixing_prob=None,
263
+ trunc_psi=None,
264
+ trunc_layers=None,
265
+ magnitude_moving_decay=0.999,
266
+ update_ema=False,
267
+ fp16_res=None,
268
+ impl='cuda'):
269
+ """Connects mapping network and synthesis network.
270
+
271
+ This forward function will also update the average `w_code`, perform
272
+ style mixing as a training regularizer, and do truncation trick, which
273
+ is specially designed for inference.
274
+
275
+ Concretely, the truncation trick acts as follows:
276
+
277
+ For layers in range [0, truncation_layers), the truncated w-code is
278
+ computed as
279
+
280
+ w_new = w_avg + (w - w_avg) * truncation_psi
281
+
282
+ To disable truncation, please set
283
+
284
+ (1) truncation_psi = 1.0 (None) OR
285
+ (2) truncation_layers = 0 (None)
286
+ """
287
+
288
+ mapping_results = self.mapping(z, label, impl=impl)
289
+
290
+ w = mapping_results['w']
291
+ if self.training and update_ema and w_moving_decay is not None:
292
+ if sync_w_avg:
293
+ batch_w_avg = all_gather(w.detach()).mean(dim=0)
294
+ else:
295
+ batch_w_avg = w.detach().mean(dim=0)
296
+ self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay))
297
+
298
+ wp = mapping_results.pop('wp')
299
+ if self.training and style_mixing_prob is not None:
300
+ if np.random.uniform() < style_mixing_prob:
301
+ new_z = torch.randn_like(z)
302
+ new_wp = self.mapping(new_z, label, impl=impl)['wp']
303
+ mixing_cutoff = np.random.randint(1, self.num_layers)
304
+ wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:]
305
+
306
+ if not self.training:
307
+ trunc_psi = 1.0 if trunc_psi is None else trunc_psi
308
+ trunc_layers = 0 if trunc_layers is None else trunc_layers
309
+ if trunc_psi < 1.0 and trunc_layers > 0:
310
+ w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers]
311
+ wp[:, :trunc_layers] = w_avg.lerp(
312
+ wp[:, :trunc_layers], trunc_psi)
313
+
314
+ synthesis_results = self.synthesis(
315
+ wp,
316
+ magnitude_moving_decay=magnitude_moving_decay,
317
+ update_ema=update_ema,
318
+ fp16_res=fp16_res,
319
+ impl=impl)
320
+
321
+ return {**mapping_results, **synthesis_results}
322
+
323
+
324
+ class MappingNetwork(nn.Module):
325
+ """Implements the latent space mapping network.
326
+
327
+ Basically, this network executes several dense layers in sequence, and the
328
+ label embedding if needed.
329
+ """
330
+
331
+ def __init__(self,
332
+ input_dim,
333
+ output_dim,
334
+ num_outputs,
335
+ repeat_output,
336
+ normalize_input,
337
+ num_layers,
338
+ hidden_dim,
339
+ lr_mul,
340
+ label_dim,
341
+ embedding_dim,
342
+ embedding_bias,
343
+ embedding_lr_mul,
344
+ normalize_embedding,
345
+ normalize_embedding_latent,
346
+ eps):
347
+ super().__init__()
348
+
349
+ self.input_dim = input_dim
350
+ self.output_dim = output_dim
351
+ self.num_outputs = num_outputs
352
+ self.repeat_output = repeat_output
353
+ self.normalize_input = normalize_input
354
+ self.num_layers = num_layers
355
+ self.hidden_dim = hidden_dim
356
+ self.lr_mul = lr_mul
357
+ self.label_dim = label_dim
358
+ self.embedding_dim = embedding_dim
359
+ self.embedding_bias = embedding_bias
360
+ self.embedding_lr_mul = embedding_lr_mul
361
+ self.normalize_embedding = normalize_embedding
362
+ self.normalize_embedding_latent = normalize_embedding_latent
363
+ self.eps = eps
364
+
365
+ self.var_mapping = {}
366
+
367
+ self.norm = PixelNormLayer(dim=1, eps=eps)
368
+
369
+ if self.label_dim > 0:
370
+ input_dim = input_dim + embedding_dim
371
+ self.embedding = DenseLayer(in_channels=label_dim,
372
+ out_channels=embedding_dim,
373
+ init_weight_std=1.0,
374
+ add_bias=embedding_bias,
375
+ init_bias=0.0,
376
+ lr_mul=embedding_lr_mul,
377
+ activation_type='linear')
378
+ self.var_mapping['embedding.weight'] = 'embed.weight'
379
+ if self.embedding_bias:
380
+ self.var_mapping['embedding.bias'] = 'embed.bias'
381
+
382
+ if num_outputs is not None and not repeat_output:
383
+ output_dim = output_dim * num_outputs
384
+ for i in range(num_layers):
385
+ in_channels = (input_dim if i == 0 else hidden_dim)
386
+ out_channels = (output_dim if i == (num_layers - 1) else hidden_dim)
387
+ self.add_module(f'dense{i}',
388
+ DenseLayer(in_channels=in_channels,
389
+ out_channels=out_channels,
390
+ init_weight_std=1.0,
391
+ add_bias=True,
392
+ init_bias=0.0,
393
+ lr_mul=lr_mul,
394
+ activation_type='lrelu'))
395
+ self.var_mapping[f'dense{i}.weight'] = f'fc{i}.weight'
396
+ self.var_mapping[f'dense{i}.bias'] = f'fc{i}.bias'
397
+
398
+ def forward(self, z, label=None, impl='cuda'):
399
+ if z.ndim != 2 or z.shape[1] != self.input_dim:
400
+ raise ValueError(f'Input latent code should be with shape '
401
+ f'[batch_size, input_dim], where '
402
+ f'`input_dim` equals to {self.input_dim}!\n'
403
+ f'But `{z.shape}` is received!')
404
+ if self.normalize_input:
405
+ z = self.norm(z)
406
+
407
+ if self.label_dim > 0:
408
+ if label is None:
409
+ raise ValueError(f'Model requires an additional label '
410
+ f'(with dimension {self.label_dim}) as input, '
411
+ f'but no label is received!')
412
+ if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim):
413
+ raise ValueError(f'Input label should be with shape '
414
+ f'[batch_size, label_dim], where '
415
+ f'`batch_size` equals to that of '
416
+ f'latent codes ({z.shape[0]}) and '
417
+ f'`label_dim` equals to {self.label_dim}!\n'
418
+ f'But `{label.shape}` is received!')
419
+ label = label.to(dtype=torch.float32)
420
+ embedding = self.embedding(label, impl=impl)
421
+ if self.normalize_embedding:
422
+ embedding = self.norm(embedding)
423
+ w = torch.cat((z, embedding), dim=1)
424
+ else:
425
+ w = z
426
+
427
+ if self.label_dim > 0 and self.normalize_embedding_latent:
428
+ w = self.norm(w)
429
+
430
+ for i in range(self.num_layers):
431
+ w = getattr(self, f'dense{i}')(w, impl=impl)
432
+
433
+ wp = None
434
+ if self.num_outputs is not None:
435
+ if self.repeat_output:
436
+ wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1))
437
+ else:
438
+ wp = w.reshape(-1, self.num_outputs, self.output_dim)
439
+
440
+ results = {
441
+ 'z': z,
442
+ 'label': label,
443
+ 'w': w,
444
+ 'wp': wp,
445
+ }
446
+ if self.label_dim > 0:
447
+ results['embedding'] = embedding
448
+ return results
449
+
450
+
451
+ class SynthesisNetwork(nn.Module):
452
+ """Implements the image synthesis network.
453
+
454
+ Basically, this network executes several convolutional layers in sequence.
455
+ """
456
+
457
+ def __init__(self,
458
+ resolution,
459
+ w_dim,
460
+ image_channels,
461
+ final_tanh,
462
+ output_scale,
463
+ num_layers,
464
+ num_critical,
465
+ fmaps_base,
466
+ fmaps_max,
467
+ kernel_size,
468
+ conv_clamp,
469
+ first_cutoff,
470
+ first_stopband,
471
+ last_stopband_rel,
472
+ margin_size,
473
+ filter_size,
474
+ act_upsampling,
475
+ use_radial_filter,
476
+ eps):
477
+ super().__init__()
478
+
479
+ self.resolution = resolution
480
+ self.w_dim = w_dim
481
+ self.image_channels = image_channels
482
+ self.final_tanh = final_tanh
483
+ self.output_scale = output_scale
484
+ self.num_layers = num_layers
485
+ self.num_critical = num_critical
486
+ self.fmaps_base = fmaps_base
487
+ self.fmaps_max = fmaps_max
488
+ self.kernel_size = kernel_size
489
+ self.conv_clamp = conv_clamp
490
+ self.first_cutoff = first_cutoff
491
+ self.first_stopband = first_stopband
492
+ self.last_stopband_rel = last_stopband_rel
493
+ self.margin_size = margin_size
494
+ self.filter_size = filter_size
495
+ self.act_upsampling = act_upsampling
496
+ self.use_radial_filter = use_radial_filter
497
+ self.eps = eps
498
+
499
+ self.var_mapping = {}
500
+
501
+ # Get layer settings.
502
+ last_cutoff = resolution / 2
503
+ last_stopband = last_cutoff * last_stopband_rel
504
+ layer_indices = np.arange(num_layers + 1)
505
+ exponents = np.minimum(layer_indices / (num_layers - num_critical), 1)
506
+ cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents
507
+ stopbands = (
508
+ first_stopband * (last_stopband / first_stopband) ** exponents)
509
+ sampling_rates = np.exp2(np.ceil(np.log2(
510
+ np.minimum(stopbands * 2, self.resolution))))
511
+ sampling_rates = np.int64(sampling_rates)
512
+ half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs
513
+ sizes = sampling_rates + margin_size * 2
514
+ sizes[-2:] = resolution
515
+ sizes = np.int64(sizes)
516
+ channels = np.rint(np.minimum((fmaps_base / 2) / cutoffs, fmaps_max))
517
+ channels[-1] = image_channels
518
+ channels = np.int64(channels)
519
+
520
+ self.cutoffs = cutoffs
521
+ self.stopbands = stopbands
522
+ self.sampling_rates = sampling_rates
523
+ self.half_widths = half_widths
524
+ self.sizes = sizes
525
+ self.channels = channels
526
+
527
+ # Input layer, with positional encoding.
528
+ self.early_layer = InputLayer(w_dim=w_dim,
529
+ channels=channels[0],
530
+ size=sizes[0],
531
+ sampling_rate=sampling_rates[0],
532
+ cutoff=cutoffs[0])
533
+ self.var_mapping['early_layer.weight'] = 'input.weight'
534
+ self.var_mapping['early_layer.affine.weight'] = 'input.affine.weight'
535
+ self.var_mapping['early_layer.affine.bias'] = 'input.affine.bias'
536
+ self.var_mapping['early_layer.transform'] = 'input.transform'
537
+ self.var_mapping['early_layer.frequency'] = 'input.freqs'
538
+ self.var_mapping['early_layer.phase'] = 'input.phases'
539
+
540
+ # Convolutional layers.
541
+ for idx in range(num_layers + 1):
542
+ # Position related settings.
543
+ if idx < num_layers:
544
+ kernel_size = self.kernel_size
545
+ demodulate = True
546
+ act_upsampling = self.act_upsampling
547
+ else: # ToRGB layer.
548
+ kernel_size = 1
549
+ demodulate = False
550
+ act_upsampling = 1
551
+ if idx < num_layers - num_critical: # Non-critical sampling.
552
+ use_radial_filter = self.use_radial_filter
553
+ else: # Critical sampling.
554
+ use_radial_filter = False
555
+
556
+ prev_idx = max(idx - 1, 0)
557
+ layer_name = f'layer{idx}'
558
+ official_layer_name = f'L{idx}_{sizes[idx]}_{channels[idx]}'
559
+ self.add_module(
560
+ layer_name,
561
+ SynthesisLayer(in_channels=channels[prev_idx],
562
+ out_channels=channels[idx],
563
+ w_dim=w_dim,
564
+ kernel_size=kernel_size,
565
+ demodulate=demodulate,
566
+ eps=eps,
567
+ conv_clamp=conv_clamp,
568
+ in_size=sizes[prev_idx],
569
+ out_size=sizes[idx],
570
+ in_sampling_rate=sampling_rates[prev_idx],
571
+ out_sampling_rate=sampling_rates[idx],
572
+ in_cutoff=cutoffs[prev_idx],
573
+ out_cutoff=cutoffs[idx],
574
+ in_half_width=half_widths[prev_idx],
575
+ out_half_width=half_widths[idx],
576
+ filter_size=filter_size,
577
+ use_radial_filter=use_radial_filter,
578
+ act_upsampling=act_upsampling))
579
+
580
+ self.var_mapping[f'{layer_name}.magnitude_ema'] = (
581
+ f'{official_layer_name}.magnitude_ema')
582
+ self.var_mapping[f'{layer_name}.conv.weight'] = (
583
+ f'{official_layer_name}.weight')
584
+ self.var_mapping[f'{layer_name}.conv.style.weight'] = (
585
+ f'{official_layer_name}.affine.weight')
586
+ self.var_mapping[f'{layer_name}.conv.style.bias'] = (
587
+ f'{official_layer_name}.affine.bias')
588
+ self.var_mapping[f'{layer_name}.filter.bias'] = (
589
+ f'{official_layer_name}.bias')
590
+ if idx < num_layers: # ToRGB layer does not need filters.
591
+ self.var_mapping[f'{layer_name}.filter.up_filter'] = (
592
+ f'{official_layer_name}.up_filter')
593
+ self.var_mapping[f'{layer_name}.filter.down_filter'] = (
594
+ f'{official_layer_name}.down_filter')
595
+
596
+ def set_space_of_latent(self, space_of_latent):
597
+ """Sets the space to which the latent code belong.
598
+
599
+ This function is particularly used for choosing how to inject the latent
600
+ code into the convolutional layers. The original generator will take a
601
+ W-Space code and apply it for style modulation after an affine
602
+ transformation. But, sometimes, it may need to directly feed an already
603
+ affine-transformed code into the convolutional layer, e.g., when
604
+ training an encoder for GAN inversion. We term the transformed space as
605
+ Style Space (or Y-Space). This function is designed to tell the
606
+ convolutional layers how to use the input code.
607
+
608
+ Args:
609
+ space_of_latent: The space to which the latent code belong. Case
610
+ insensitive. Support `W` and `Y`.
611
+ """
612
+ space_of_latent = space_of_latent.upper()
613
+ for module in self.modules():
614
+ if isinstance(module, ModulateConvLayer):
615
+ setattr(module, 'space_of_latent', space_of_latent)
616
+
617
+ def forward(self,
618
+ wp,
619
+ magnitude_moving_decay=0.999,
620
+ update_ema=False,
621
+ fp16_res=None,
622
+ impl='cuda'):
623
+ results = {'wp': wp}
624
+
625
+ x = self.early_layer(wp[:, 0])
626
+ for idx, sampling_rate in enumerate(self.sampling_rates):
627
+ if fp16_res is not None and sampling_rate >= fp16_res:
628
+ x = x.to(torch.float16)
629
+ layer = getattr(self, f'layer{idx}')
630
+ x, style = layer(x, wp[:, idx + 1],
631
+ magnitude_moving_decay=magnitude_moving_decay,
632
+ update_ema=update_ema,
633
+ impl=impl)
634
+ results[f'style{idx}'] = style
635
+
636
+ if self.output_scale != 1:
637
+ x = x * self.output_scale
638
+ x = x.to(torch.float32)
639
+ if self.final_tanh:
640
+ x = torch.tanh(x)
641
+ results['image'] = x
642
+ return results
643
+
644
+
645
+ class PixelNormLayer(nn.Module):
646
+ """Implements pixel-wise feature vector normalization layer."""
647
+
648
+ def __init__(self, dim, eps):
649
+ super().__init__()
650
+ self.dim = dim
651
+ self.eps = eps
652
+
653
+ def extra_repr(self):
654
+ return f'dim={self.dim}, epsilon={self.eps}'
655
+
656
+ def forward(self, x):
657
+ scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
658
+ return x * scale
659
+
660
+
661
+ class InputLayer(nn.Module):
662
+ """Implements the input layer with positional encoding.
663
+
664
+ Basically, this block outputs a feature map with shape
665
+ `(channels, size, size)` based on the coordinate information.
666
+ `sampling_rate` and `cutoff` are used to control the coordinate range and
667
+ strength respectively.
668
+
669
+ For a low-pass filter, `cutoff` is the same as the `bandwidth`.
670
+ The initial frequency of the starting feature map is controlled by the
671
+ positional encoding `sin(2 * pi * x)`, where
672
+ `x = trans(coord) * frequency + phase`. We would like to introduce rich
673
+ information (i.e. frequencies), but keep all frequencies lower than
674
+ stopband, which is `sampling_rate / 2`.
675
+
676
+ Besides, this layer also supports learning a transformation from the latent
677
+ code w, and providing a customized transformation for inference. Please
678
+ use the buffer `transform`.
679
+
680
+ NOTE: `size` is different from `sampling_rate`. `sampling_rate` is the
681
+ actual size of the current stage, which determines the maximum frequency
682
+ that the feature maps can hold. `size` is the actual height and width of the
683
+ current feature map, including the extended border.
684
+ """
685
+
686
+ def __init__(self, w_dim, channels, size, sampling_rate, cutoff):
687
+ super().__init__()
688
+
689
+ self.w_dim = w_dim
690
+ self.channels = channels
691
+ self.size = size
692
+ self.sampling_rate = sampling_rate
693
+ self.cutoff = cutoff
694
+
695
+ # Coordinate of the entire feature map, with resolution (size, size).
696
+ # The coordinate range for the central (sampling_rate, sampling_rate)
697
+ # region is set as (-0.0, 0.5), which extends to the remaining region.
698
+ theta = torch.eye(2, 3)
699
+ theta[0, 0] = 0.5 / sampling_rate * size
700
+ theta[1, 1] = 0.5 / sampling_rate * size
701
+ grid = F.affine_grid(theta=theta.unsqueeze(0),
702
+ size=(1, 1, size, size),
703
+ align_corners=False)
704
+ self.register_buffer('grid', grid)
705
+
706
+ # Draw random frequency from a uniform 2D disc for each channel
707
+ # regarding X and Y dimension. And also draw a random phase for each
708
+ # channel. Accordingly, each channel has three pre-defined parameters,
709
+ # which are X-frequency, Y-frequency, and phase.
710
+ frequency = torch.randn(channels, 2)
711
+ radius = frequency.square().sum(dim=1, keepdim=True).sqrt()
712
+ frequency = frequency / (radius * radius.square().exp().pow(0.25))
713
+ frequency = frequency * cutoff
714
+ self.register_buffer('frequency', frequency)
715
+ phase = torch.rand(channels) - 0.5
716
+ self.register_buffer('phase', phase)
717
+
718
+ # This layer is used to map the latent code w to transform factors,
719
+ # with order: cos(angle), sin(angle), transpose_x, transpose_y.
720
+ self.affine = DenseLayer(in_channels=w_dim,
721
+ out_channels=4,
722
+ init_weight_std=0.0,
723
+ add_bias=True,
724
+ init_bias=(1, 0, 0, 0),
725
+ lr_mul=1.0,
726
+ activation_type='linear')
727
+
728
+ # It is possible to use this buffer to customize the transform of the
729
+ # output synthesis.
730
+ self.register_buffer('transform', torch.eye(3))
731
+
732
+ # Use 1x1 conv to convert positional encoding to features.
733
+ self.weight = nn.Parameter(torch.randn(channels, channels))
734
+ self.weight_scale = 1 / np.sqrt(channels)
735
+
736
+ def extra_repr(self):
737
+ return (f'channels={self.channels}, '
738
+ f'size={self.size}, '
739
+ f'sampling_rate={self.sampling_rate}, '
740
+ f'cutoff={self.cutoff:.3f}, ')
741
+
742
+ def forward(self, w):
743
+ batch = w.shape[0]
744
+
745
+ # Get transformation matrix.
746
+ # Factor controlled by latent code.
747
+ transformation_factor = self.affine(w)
748
+ # Ensure the range of cosine and sine value (first two dimension).
749
+ _norm = transformation_factor[:, :2].norm(dim=1, keepdim=True)
750
+ transformation_factor = transformation_factor / _norm
751
+ # Rotation.
752
+ rotation = torch.eye(3, device=w.device).unsqueeze(0)
753
+ rotation = rotation.repeat((batch, 1, 1))
754
+ rotation[:, 0, 0] = transformation_factor[:, 0]
755
+ rotation[:, 0, 1] = -transformation_factor[:, 1]
756
+ rotation[:, 1, 0] = transformation_factor[:, 1]
757
+ rotation[:, 1, 1] = transformation_factor[:, 0]
758
+ # Translation.
759
+ translation = torch.eye(3, device=w.device).unsqueeze(0)
760
+ translation = translation.repeat((batch, 1, 1))
761
+ translation[:, 0, 2] = -transformation_factor[:, 2]
762
+ translation[:, 1, 2] = -transformation_factor[:, 3]
763
+ # Customized transformation.
764
+ transform = rotation @ translation @ self.transform.unsqueeze(0)
765
+
766
+ # Transform frequency and shift, which is equivalent to transforming
767
+ # the coordinate. For example, given a coordinate, X, we would like to
768
+ # first transform it with the rotation matrix, R, and the translation
769
+ # matrix, T, as X' = RX + T. Then, we will apply frequency, f, and
770
+ # phase, p, with sin(2 * pi * (fX' + p)). Natively, we have
771
+ # fX' + p = f(RX + T) + p = (fR)X + (fT + p)
772
+ frequency = self.frequency.unsqueeze(0) @ transform[:, :2, :2] # [NC2]
773
+ phase = self.frequency.unsqueeze(0) @ transform[:, :2, 2:] # [NC]
774
+ phase = phase.squeeze(2) + self.phase.unsqueeze(0) # [NC]
775
+
776
+ # Positional encoding.
777
+ x = self.grid # [NHW2]
778
+ x = x.unsqueeze(3) # [NHW12]
779
+ x = x @ frequency.transpose(1, 2).unsqueeze(1).unsqueeze(2) # [NHW1C]
780
+ x = x.squeeze(3) # [NHWC]
781
+ x = x + phase.unsqueeze(1).unsqueeze(2) # [NHWC]
782
+ x = torch.sin(2 * np.pi * x) # [NHWC]
783
+
784
+ # Dampen out-of-band frequency that may be introduced by the customized
785
+ # transform `self.transform`.
786
+ frequency_norm = frequency.norm(dim=2)
787
+ stopband = self.sampling_rate / 2
788
+ factor = (frequency_norm - self.cutoff) / (stopband - self.cutoff)
789
+ amplitude = (1 - factor).clamp(0, 1) # [NC]
790
+ x = x * amplitude.unsqueeze(1).unsqueeze(2) # [NHWC]
791
+
792
+ # Project positional encoding to features.
793
+ weight = self.weight * self.weight_scale
794
+ x = x @ weight.t()
795
+
796
+ return x.permute(0, 3, 1, 2).contiguous()
797
+
798
+
799
+ class SynthesisLayer(nn.Module):
800
+ """Implements the synthesis layer.
801
+
802
+ Each synthesis layer (including ToRGB layer) consists of a
803
+ `ModulateConvLayer` and a `FilteringActLayer`. Besides, this layer will
804
+ trace the magnitude (norm) of the input feature map, and update the
805
+ statistic with `magnitude_moving_decay`.
806
+ """
807
+
808
+ def __init__(self,
809
+ # Settings for modulated convolution.
810
+ in_channels,
811
+ out_channels,
812
+ w_dim,
813
+ kernel_size,
814
+ demodulate,
815
+ eps,
816
+ conv_clamp,
817
+ # Settings for filtering activation.
818
+ in_size,
819
+ out_size,
820
+ in_sampling_rate,
821
+ out_sampling_rate,
822
+ in_cutoff,
823
+ out_cutoff,
824
+ in_half_width,
825
+ out_half_width,
826
+ filter_size,
827
+ use_radial_filter,
828
+ act_upsampling):
829
+ """Initializes with layer settings.
830
+
831
+ Args:
832
+ in_channels: Number of channels of the input tensor.
833
+ out_channels: Number of channels of the output tensor.
834
+ w_dim: Dimension of W space for style modulation.
835
+ kernel_size: Size of the convolutional kernels.
836
+ demodulate: Whether to perform style demodulation.
837
+ eps: A small value to avoid divide overflow.
838
+ conv_clamp: A threshold to clamp the output of convolution layers to
839
+ avoid overflow under FP16 training.
840
+ in_size: Size of the input feature map, i.e., height and width.
841
+ out_size: Size of the output feature map, i.e., height and width.
842
+ in_sampling_rate: Sampling rate of the input feature map. Different
843
+ from `in_size` that includes extended border, this field
844
+ controls the actual maximum frequency that can be represented
845
+ by the feature map.
846
+ out_sampling_rate: Sampling rate of the output feature map.
847
+ in_cutoff: Cutoff frequency of the input feature map.
848
+ out_cutoff: Cutoff frequency of the output feature map.
849
+ in_half_width: Half-width of the transition band of the input
850
+ feature map.
851
+ out_half_width: Half-width of the transition band of the output
852
+ feature map.
853
+ filter_size: Size of the filter used in this layer.
854
+ use_radial_filter: Whether to use radial filter.
855
+ act_upsampling: Upsampling factor used before the activation.
856
+ `1` means do not wrap upsampling and downsampling around the
857
+ activation.
858
+ """
859
+ super().__init__()
860
+
861
+ self.in_channels = in_channels
862
+ self.out_channels = out_channels
863
+ self.w_dim = w_dim
864
+ self.kernel_size = kernel_size
865
+ self.demodulate = demodulate
866
+ self.eps = eps
867
+ self.conv_clamp = conv_clamp
868
+
869
+ self.in_size = in_size
870
+ self.out_size = out_size
871
+ self.in_sampling_rate = in_sampling_rate
872
+ self.out_sampling_rate = out_sampling_rate
873
+ self.in_cutoff = in_cutoff
874
+ self.out_cutoff = out_cutoff
875
+ self.in_half_width = in_half_width
876
+ self.out_half_width = out_half_width
877
+ self.filter_size = filter_size
878
+ self.use_radial_filter = use_radial_filter
879
+ self.act_upsampling = act_upsampling
880
+
881
+ self.conv = ModulateConvLayer(in_channels=in_channels,
882
+ out_channels=out_channels,
883
+ w_dim=w_dim,
884
+ kernel_size=kernel_size,
885
+ demodulate=demodulate,
886
+ eps=eps)
887
+ self.register_buffer('magnitude_ema', torch.ones(()))
888
+ self.filter = FilteringActLayer(out_channels=out_channels,
889
+ in_size=in_size,
890
+ out_size=out_size,
891
+ in_sampling_rate=in_sampling_rate,
892
+ out_sampling_rate=out_sampling_rate,
893
+ in_cutoff=in_cutoff,
894
+ out_cutoff=out_cutoff,
895
+ in_half_width=in_half_width,
896
+ out_half_width=out_half_width,
897
+ filter_size=filter_size,
898
+ use_radial_filter=use_radial_filter,
899
+ conv_padding=self.conv.padding,
900
+ act_upsampling=act_upsampling)
901
+
902
+ def extra_repr(self):
903
+ return f'conv_clamp={self.conv_clamp}'
904
+
905
+ def forward(self,
906
+ x,
907
+ w,
908
+ magnitude_moving_decay=0.999,
909
+ update_ema=False,
910
+ impl='cuda'):
911
+ if self.training and update_ema and magnitude_moving_decay is not None:
912
+ magnitude = x.detach().to(torch.float32).square().mean()
913
+ self.magnitude_ema.copy_(
914
+ magnitude.lerp(self.magnitude_ema, magnitude_moving_decay))
915
+
916
+ input_gain = self.magnitude_ema.rsqrt()
917
+ x, style = self.conv(x, w, gain=input_gain, impl=impl)
918
+ if self.act_upsampling > 1:
919
+ x = self.filter(x, np.sqrt(2), 0.2, self.conv_clamp, impl=impl)
920
+ else:
921
+ x = self.filter(x, 1, 1, self.conv_clamp, impl=impl)
922
+
923
+ return x, style
924
+
925
+
926
+ class ModulateConvLayer(nn.Module):
927
+ """Implements the convolutional layer with style modulation.
928
+
929
+ Different from the one introduced in StyleGAN2, this layer has following
930
+ changes:
931
+
932
+ (1) fusing `conv` and `style modulation` into one op by default
933
+ (2) NOT adding a noise onto the output feature map.
934
+ (3) NOT activating the feature map, which is moved to `FilteringActLayer`.
935
+ """
936
+
937
+ def __init__(self,
938
+ in_channels,
939
+ out_channels,
940
+ w_dim,
941
+ kernel_size,
942
+ demodulate,
943
+ eps):
944
+ """Initializes with layer settings.
945
+
946
+ Args:
947
+ in_channels: Number of channels of the input tensor.
948
+ out_channels: Number of channels of the output tensor.
949
+ w_dim: Dimension of W space for style modulation.
950
+ kernel_size: Size of the convolutional kernels.
951
+ demodulate: Whether to perform style demodulation.
952
+ eps: A small value to avoid divide overflow.
953
+ """
954
+ super().__init__()
955
+
956
+ self.in_channels = in_channels
957
+ self.out_channels = out_channels
958
+ self.w_dim = w_dim
959
+ self.kernel_size = kernel_size
960
+ self.demodulate = demodulate
961
+ self.eps = eps
962
+
963
+ self.space_of_latent = 'W'
964
+
965
+ # Set up weight.
966
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
967
+ self.weight = nn.Parameter(torch.randn(*weight_shape))
968
+ self.wscale = 1.0 / np.sqrt(kernel_size * kernel_size * in_channels)
969
+ self.padding = kernel_size - 1
970
+
971
+ # Set up style.
972
+ self.style = DenseLayer(in_channels=w_dim,
973
+ out_channels=in_channels,
974
+ init_weight_std=1.0,
975
+ add_bias=True,
976
+ init_bias=1.0,
977
+ lr_mul=1.0,
978
+ activation_type='linear')
979
+
980
+ def extra_repr(self):
981
+ return (f'in_ch={self.in_channels}, '
982
+ f'out_ch={self.out_channels}, '
983
+ f'ksize={self.kernel_size}, '
984
+ f'demodulate={self.demodulate}')
985
+
986
+ def forward_style(self, w, impl='cuda'):
987
+ """Gets style code from the given input.
988
+
989
+ More specifically, if the input is from W-Space, it will be projected by
990
+ an affine transformation. If it is from the Style Space (Y-Space), no
991
+ operation is required.
992
+
993
+ NOTE: For codes from Y-Space, we use slicing to make sure the dimension
994
+ is correct, in case that the code is padded before fed into this layer.
995
+ """
996
+ space_of_latent = self.space_of_latent.upper()
997
+ if space_of_latent == 'W':
998
+ if w.ndim != 2 or w.shape[1] != self.w_dim:
999
+ raise ValueError(f'The input tensor should be with shape '
1000
+ f'[batch_size, w_dim], where '
1001
+ f'`w_dim` equals to {self.w_dim}!\n'
1002
+ f'But `{w.shape}` is received!')
1003
+ style = self.style(w, impl=impl)
1004
+ elif space_of_latent == 'Y':
1005
+ if w.ndim != 2 or w.shape[1] < self.in_channels:
1006
+ raise ValueError(f'The input tensor should be with shape '
1007
+ f'[batch_size, y_dim], where '
1008
+ f'`y_dim` equals to {self.in_channels}!\n'
1009
+ f'But `{w.shape}` is received!')
1010
+ style = w[:, :self.in_channels]
1011
+ else:
1012
+ raise NotImplementedError(f'Not implemented `space_of_latent`: '
1013
+ f'`{space_of_latent}`!')
1014
+ return style
1015
+
1016
+ def forward(self, x, w, gain=None, impl='cuda'):
1017
+ dtype = x.dtype
1018
+ N, C, H, W = x.shape
1019
+
1020
+ # Affine on `w`.
1021
+ style = self.forward_style(w, impl=impl)
1022
+ if not self.demodulate:
1023
+ _style = style * self.wscale # Equivalent to scaling weight.
1024
+ else:
1025
+ _style = style
1026
+
1027
+ weight = self.weight
1028
+ out_ch, in_ch, kh, kw = weight.shape
1029
+ assert in_ch == C
1030
+
1031
+ # Pre-normalize inputs.
1032
+ if self.demodulate:
1033
+ weight = (weight *
1034
+ weight.square().mean(dim=(1, 2, 3), keepdim=True).rsqrt())
1035
+ _style = _style * _style.square().mean().rsqrt()
1036
+
1037
+ weight = weight.unsqueeze(0)
1038
+ weight = weight * _style.reshape(N, 1, in_ch, 1, 1) # modulation
1039
+ if self.demodulate:
1040
+ decoef = (weight.square().sum(dim=(2, 3, 4)) + self.eps).rsqrt()
1041
+ weight = weight * decoef.reshape(N, out_ch, 1, 1, 1) # demodulation
1042
+
1043
+ if gain is not None:
1044
+ gain = gain.expand(N, in_ch)
1045
+ weight = weight * gain.reshape(N, 1, in_ch, 1, 1)
1046
+
1047
+ # Fuse `conv` and `style modulation` as one op, using group convolution.
1048
+ x = x.reshape(1, N * in_ch, H, W)
1049
+ w = weight.reshape(N * out_ch, in_ch, kh, kw).to(dtype)
1050
+ x = conv2d_gradfix.conv2d(
1051
+ x, w, padding=self.padding, groups=N, impl=impl)
1052
+ x = x.reshape(N, out_ch, x.shape[2], x.shape[3])
1053
+
1054
+ assert x.dtype == dtype
1055
+ assert style.dtype == torch.float32
1056
+ return x, style
1057
+
1058
+
1059
+ class FilteringActLayer(nn.Module):
1060
+ """Implements the activation, wrapped with upsampling and downsampling.
1061
+
1062
+ Basically, this layer executes the following operations in order:
1063
+
1064
+ (1) Apply bias.
1065
+ (2) Upsample the feature map to increase sampling rate.
1066
+ (3) Apply non-linearity as activation.
1067
+ (4) Downsample the feature map to target size.
1068
+
1069
+ This layer is mostly borrowed from the official implementation:
1070
+
1071
+ https://github.com/NVlabs/stylegan3/blob/main/training/networks_stylegan3.py
1072
+ """
1073
+
1074
+ def __init__(self,
1075
+ out_channels,
1076
+ in_size,
1077
+ out_size,
1078
+ in_sampling_rate,
1079
+ out_sampling_rate,
1080
+ in_cutoff,
1081
+ out_cutoff,
1082
+ in_half_width,
1083
+ out_half_width,
1084
+ filter_size,
1085
+ use_radial_filter,
1086
+ conv_padding,
1087
+ act_upsampling):
1088
+ """Initializes with layer settings.
1089
+
1090
+ Args:
1091
+ out_channels: Number of output channels, which is used for `bias`.
1092
+ in_size: Size of the input feature map, i.e., height and width.
1093
+ out_size: Size of the output feature map, i.e., height and width.
1094
+ in_sampling_rate: Sampling rate of the input feature map. Different
1095
+ from `in_size` that includes extended border, this field
1096
+ controls the actual maximum frequency that can be represented
1097
+ by the feature map.
1098
+ out_sampling_rate: Sampling rate of the output feature map.
1099
+ in_cutoff: Cutoff frequency of the input feature map.
1100
+ out_cutoff: Cutoff frequency of the output feature map.
1101
+ in_half_width: Half-width of the transition band of the input
1102
+ feature map.
1103
+ out_half_width: Half-width of the transition band of the output
1104
+ feature map.
1105
+ filter_size: Size of the filter used in this layer.
1106
+ use_radial_filter: Whether to use radial filter.
1107
+ conv_padding: The padding used in the previous convolutional layer.
1108
+ act_upsampling: Upsampling factor used before the activation.
1109
+ `1` means do not wrap upsampling and downsampling around the
1110
+ activation.
1111
+ """
1112
+ super().__init__()
1113
+
1114
+ self.out_channels = out_channels
1115
+ self.in_size = in_size
1116
+ self.out_size = out_size
1117
+ self.in_sampling_rate = in_sampling_rate
1118
+ self.out_sampling_rate = out_sampling_rate
1119
+ self.in_cutoff = in_cutoff
1120
+ self.out_cutoff = out_cutoff
1121
+ self.in_half_width = in_half_width
1122
+ self.out_half_width = out_half_width
1123
+ self.filter_size = filter_size
1124
+ self.use_radial_filter = use_radial_filter
1125
+ self.conv_padding = conv_padding
1126
+ self.act_upsampling = act_upsampling
1127
+
1128
+ # Define bias.
1129
+ self.bias = nn.Parameter(torch.zeros(out_channels))
1130
+
1131
+ # This sampling rate describes the upsampled feature map before
1132
+ # activation.
1133
+ temp_sampling_rate = max(in_sampling_rate, out_sampling_rate)
1134
+ temp_sampling_rate = temp_sampling_rate * act_upsampling
1135
+
1136
+ # Design upsampling filter.
1137
+ up_factor = int(np.rint(temp_sampling_rate / in_sampling_rate))
1138
+ assert in_sampling_rate * up_factor == temp_sampling_rate
1139
+ if up_factor > 1:
1140
+ self.up_factor = up_factor
1141
+ self.up_taps = filter_size * up_factor
1142
+ else:
1143
+ self.up_factor = 1
1144
+ self.up_taps = 1 # No filtering.
1145
+ self.register_buffer(
1146
+ 'up_filter',
1147
+ self.design_lowpass_filter(numtaps=self.up_taps,
1148
+ cutoff=in_cutoff,
1149
+ width=in_half_width * 2,
1150
+ fs=temp_sampling_rate,
1151
+ radial=False))
1152
+
1153
+ # Design downsampling filter.
1154
+ down_factor = int(np.rint(temp_sampling_rate / out_sampling_rate))
1155
+ assert out_sampling_rate * down_factor == temp_sampling_rate
1156
+ if down_factor > 1:
1157
+ self.down_factor = down_factor
1158
+ self.down_taps = filter_size * down_factor
1159
+ else:
1160
+ self.down_factor = 1
1161
+ self.down_taps = 1 # No filtering.
1162
+ self.register_buffer(
1163
+ 'down_filter',
1164
+ self.design_lowpass_filter(numtaps=self.down_taps,
1165
+ cutoff=out_cutoff,
1166
+ width=out_half_width * 2,
1167
+ fs=temp_sampling_rate,
1168
+ radial=use_radial_filter))
1169
+
1170
+ # Compute padding.
1171
+ # Desired output size before downsampling.
1172
+ pad_total = (out_size - 1) * self.down_factor + 1
1173
+ # Input size after upsampling.
1174
+ pad_total = pad_total - (in_size + conv_padding) * self.up_factor
1175
+ # Size reduction caused by the filters.
1176
+ pad_total = pad_total + self.up_taps + self.down_taps - 2
1177
+ # Shift sample locations according to the symmetric interpretation.
1178
+ pad_lo = (pad_total + self.up_factor) // 2
1179
+ pad_hi = pad_total - pad_lo
1180
+ self.padding = list(map(int, (pad_lo, pad_hi, pad_lo, pad_hi)))
1181
+
1182
+ def extra_repr(self):
1183
+ return (f'in_size={self.in_size}, '
1184
+ f'out_size={self.out_size}, '
1185
+ f'in_srate={self.in_sampling_rate}, '
1186
+ f'out_srate={self.out_sampling_rate}, '
1187
+ f'in_cutoff={self.in_cutoff:.3f}, '
1188
+ f'out_cutoff={self.out_cutoff:.3f}, '
1189
+ f'in_half_width={self.in_half_width:.3f}, '
1190
+ f'out_half_width={self.out_half_width:.3f}, '
1191
+ f'up_factor={self.up_factor}, '
1192
+ f'up_taps={self.up_taps}, '
1193
+ f'down_factor={self.down_factor}, '
1194
+ f'down_taps={self.down_taps}, '
1195
+ f'filter_size={self.filter_size}, '
1196
+ f'radial_filter={self.use_radial_filter}, '
1197
+ f'conv_padding={self.conv_padding}, '
1198
+ f'act_upsampling={self.act_upsampling}')
1199
+
1200
+ @staticmethod
1201
+ def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
1202
+ """Designs a low-pass filter.
1203
+
1204
+ Args:
1205
+ numtaps: Length of the filter (number of coefficients, i.e., the
1206
+ filter order + 1).
1207
+ cutoff: Cutoff frequency of the output filter.
1208
+ width: Width of the transition region.
1209
+ fs: Sampling frequency.
1210
+ radial: Whether to use radially symmetric jinc-based filter.
1211
+ (default: False)
1212
+ """
1213
+ if numtaps == 1:
1214
+ return None
1215
+
1216
+ assert numtaps > 1
1217
+
1218
+ if not radial: # Separable Kaiser low-pass filter.
1219
+ f = scipy.signal.firwin(numtaps=numtaps,
1220
+ cutoff=cutoff,
1221
+ width=width,
1222
+ fs=fs)
1223
+ else: # Radially symmetric jinc-based filter.
1224
+ x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
1225
+ r = np.hypot(*np.meshgrid(x, x))
1226
+ f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
1227
+ beta = scipy.signal.kaiser_beta(
1228
+ scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
1229
+ w = np.kaiser(numtaps, beta)
1230
+ f = f * np.outer(w, w)
1231
+ f = f / np.sum(f)
1232
+ return torch.as_tensor(f, dtype=torch.float32)
1233
+
1234
+ def forward(self, x, gain, slope, clamp, impl='cuda'):
1235
+ dtype = x.dtype
1236
+
1237
+ x = filtered_lrelu.filtered_lrelu(x=x,
1238
+ fu=self.up_filter,
1239
+ fd=self.down_filter,
1240
+ b=self.bias.to(dtype),
1241
+ up=self.up_factor,
1242
+ down=self.down_factor,
1243
+ padding=self.padding,
1244
+ gain=gain,
1245
+ slope=slope,
1246
+ clamp=clamp,
1247
+ impl=impl)
1248
+
1249
+ assert x.dtype == dtype
1250
+ return x
1251
+
1252
+
1253
+ class DenseLayer(nn.Module):
1254
+ """Implements the dense layer."""
1255
+
1256
+ def __init__(self,
1257
+ in_channels,
1258
+ out_channels,
1259
+ init_weight_std,
1260
+ add_bias,
1261
+ init_bias,
1262
+ lr_mul,
1263
+ activation_type):
1264
+ """Initializes with layer settings.
1265
+
1266
+ Args:
1267
+ in_channels: Number of channels of the input tensor.
1268
+ out_channels: Number of channels of the output tensor.
1269
+ init_weight_std: The initial standard deviation of weight.
1270
+ add_bias: Whether to add bias onto the fully-connected result.
1271
+ init_bias: The initial bias value before training.
1272
+ lr_mul: Learning multiplier for both weight and bias.
1273
+ activation_type: Type of activation.
1274
+ """
1275
+ super().__init__()
1276
+ self.in_channels = in_channels
1277
+ self.out_channels = out_channels
1278
+ self.init_weight_std = init_weight_std
1279
+ self.add_bias = add_bias
1280
+ self.init_bias = init_bias
1281
+ self.lr_mul = lr_mul
1282
+ self.activation_type = activation_type
1283
+
1284
+ weight_shape = (out_channels, in_channels)
1285
+ self.weight = nn.Parameter(
1286
+ torch.randn(*weight_shape) * init_weight_std / lr_mul)
1287
+ self.wscale = lr_mul / np.sqrt(in_channels)
1288
+
1289
+ if add_bias:
1290
+ init_bias = np.float32(np.float32(init_bias) / lr_mul)
1291
+ if isinstance(init_bias, np.float32):
1292
+ self.bias = nn.Parameter(torch.full([out_channels], init_bias))
1293
+ else:
1294
+ assert isinstance(init_bias, np.ndarray)
1295
+ self.bias = nn.Parameter(torch.from_numpy(init_bias))
1296
+ self.bscale = lr_mul
1297
+ else:
1298
+ self.bias = None
1299
+
1300
+ def extra_repr(self):
1301
+ return (f'in_ch={self.in_channels}, '
1302
+ f'out_ch={self.out_channels}, '
1303
+ f'init_weight_std={self.init_weight_std}, '
1304
+ f'bias={self.add_bias}, '
1305
+ f'init_bias={self.init_bias}, '
1306
+ f'lr_mul={self.lr_mul:.3f}, '
1307
+ f'act={self.activation_type}')
1308
+
1309
+ def forward(self, x, impl='cuda'):
1310
+ dtype = x.dtype
1311
+
1312
+ if x.ndim != 2:
1313
+ x = x.flatten(start_dim=1)
1314
+
1315
+ weight = self.weight.to(dtype) * self.wscale
1316
+ bias = None
1317
+ if self.bias is not None:
1318
+ bias = self.bias.to(dtype)
1319
+ if self.bscale != 1.0:
1320
+ bias = bias * self.bscale
1321
+
1322
+ # Fast pass for linear activation.
1323
+ if self.activation_type == 'linear' and bias is not None:
1324
+ x = torch.addmm(bias.unsqueeze(0), x, weight.t())
1325
+ else:
1326
+ x = x.matmul(weight.t())
1327
+ x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl)
1328
+
1329
+ assert x.dtype == dtype
1330
+ return x
1331
+
1332
+ # pylint: enable=missing-function-docstring
models/stylegan_discriminator.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the implementation of discriminator described in StyleGAN.
3
+
4
+ Paper: https://arxiv.org/pdf/1812.04948.pdf
5
+
6
+ Official TensorFlow implementation: https://github.com/NVlabs/stylegan
7
+ """
8
+
9
+ import numpy as np
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.cuda.amp import autocast
15
+
16
+ __all__ = ['StyleGANDiscriminator']
17
+
18
+ # Resolutions allowed.
19
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
20
+
21
+ # Fused-scale options allowed.
22
+ _FUSED_SCALE_ALLOWED = [True, False, 'auto']
23
+
24
+ # pylint: disable=missing-function-docstring
25
+
26
+ class StyleGANDiscriminator(nn.Module):
27
+ """Defines the discriminator network in StyleGAN.
28
+
29
+ NOTE: The discriminator takes images with `RGB` channel order and pixel
30
+ range [-1, 1] as inputs.
31
+
32
+ Settings for the backbone:
33
+
34
+ (1) resolution: The resolution of the input image. (default: -1)
35
+ (2) init_res: Smallest resolution of the convolutional backbone.
36
+ (default: 4)
37
+ (3) image_channels: Number of channels of the input image. (default: 3)
38
+ (4) fused_scale: The strategy of fusing `conv2d` and `downsample` as one
39
+ operator. `True` means blocks from all resolutions will fuse. `False`
40
+ means blocks from all resolutions will not fuse. `auto` means blocks
41
+ from resolutions higher than (or equal to) `fused_scale_res` will fuse.
42
+ (default: `auto`)
43
+ (5) fused_scale_res: Minimum resolution to fuse `conv2d` and `downsample`
44
+ as one operator. This field only takes effect if `fused_scale` is set
45
+ as `auto`. (default: 128)
46
+ (6) use_wscale: Whether to use weight scaling. (default: True)
47
+ (7) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0))
48
+ (8) lr_mul: Learning rate multiplier for backbone. (default: 1.0)
49
+ (9) mbstd_groups: Group size for the minibatch standard deviation layer.
50
+ `0` means disable. (default: 4)
51
+ (10) mbstd_channels: Number of new channels (appended to the original
52
+ feature map) after the minibatch standard deviation layer. (default: 1)
53
+ (11) fmaps_base: Factor to control number of feature maps for each layer.
54
+ (default: 16 << 10)
55
+ (12) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
56
+ (13) filter_kernel: Kernel used for filtering (e.g., downsampling).
57
+ (default: (1, 2, 1))
58
+ (14) eps: A small value to avoid divide overflow. (default: 1e-8)
59
+
60
+ Settings for conditional model:
61
+
62
+ (1) label_dim: Dimension of the additional label for conditional generation.
63
+ In one-hot conditioning case, it is equal to the number of classes. If
64
+ set to 0, conditioning training will be disabled. (default: 0)
65
+
66
+ Runtime settings:
67
+
68
+ (1) enable_amp: Whether to enable automatic mixed precision training.
69
+ (default: False)
70
+ """
71
+
72
+ def __init__(self,
73
+ # Settings for backbone.
74
+ resolution=-1,
75
+ init_res=4,
76
+ image_channels=3,
77
+ fused_scale='auto',
78
+ fused_scale_res=128,
79
+ use_wscale=True,
80
+ wscale_gain=np.sqrt(2.0),
81
+ lr_mul=1.0,
82
+ mbstd_groups=4,
83
+ mbstd_channels=1,
84
+ fmaps_base=16 << 10,
85
+ fmaps_max=512,
86
+ filter_kernel=(1, 2, 1),
87
+ eps=1e-8,
88
+ # Settings for conditional model.
89
+ label_dim=0):
90
+ """Initializes with basic settings.
91
+
92
+ Raises:
93
+ ValueError: If the `resolution` is not supported, or `fused_scale`
94
+ is not supported.
95
+ """
96
+ super().__init__()
97
+
98
+ if resolution not in _RESOLUTIONS_ALLOWED:
99
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
100
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
101
+ if fused_scale not in _FUSED_SCALE_ALLOWED:
102
+ raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
103
+ f'Options allowed: {_FUSED_SCALE_ALLOWED}.')
104
+
105
+ self.init_res = init_res
106
+ self.init_res_log2 = int(np.log2(init_res))
107
+ self.resolution = resolution
108
+ self.final_res_log2 = int(np.log2(resolution))
109
+ self.image_channels = image_channels
110
+ self.fused_scale = fused_scale
111
+ self.fused_scale_res = fused_scale_res
112
+ self.use_wscale = use_wscale
113
+ self.wscale_gain = wscale_gain
114
+ self.lr_mul = lr_mul
115
+ self.mbstd_groups = mbstd_groups
116
+ self.mbstd_channels = mbstd_channels
117
+ self.fmaps_base = fmaps_base
118
+ self.fmaps_max = fmaps_max
119
+ self.filter_kernel = filter_kernel
120
+ self.eps = eps
121
+ self.label_dim = label_dim
122
+
123
+ # Level-of-details (used for progressive training).
124
+ self.register_buffer('lod', torch.zeros(()))
125
+ self.pth_to_tf_var_mapping = {'lod': 'lod'}
126
+
127
+ for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
128
+ res = 2 ** res_log2
129
+ in_channels = self.get_nf(res)
130
+ out_channels = self.get_nf(res // 2)
131
+ block_idx = self.final_res_log2 - res_log2
132
+
133
+ # Input convolution layer for each resolution.
134
+ self.add_module(
135
+ f'input{block_idx}',
136
+ ConvLayer(in_channels=image_channels,
137
+ out_channels=in_channels,
138
+ kernel_size=1,
139
+ add_bias=True,
140
+ scale_factor=1,
141
+ fused_scale=False,
142
+ filter_kernel=None,
143
+ use_wscale=use_wscale,
144
+ wscale_gain=wscale_gain,
145
+ lr_mul=lr_mul,
146
+ activation_type='lrelu'))
147
+ self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
148
+ f'FromRGB_lod{block_idx}/weight')
149
+ self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
150
+ f'FromRGB_lod{block_idx}/bias')
151
+
152
+ # Convolution block for each resolution (except the last one).
153
+ if res != self.init_res:
154
+ # First layer (kernel 3x3) without downsampling.
155
+ layer_name = f'layer{2 * block_idx}'
156
+ self.add_module(
157
+ layer_name,
158
+ ConvLayer(in_channels=in_channels,
159
+ out_channels=in_channels,
160
+ kernel_size=3,
161
+ add_bias=True,
162
+ scale_factor=1,
163
+ fused_scale=False,
164
+ filter_kernel=None,
165
+ use_wscale=use_wscale,
166
+ wscale_gain=wscale_gain,
167
+ lr_mul=lr_mul,
168
+ activation_type='lrelu'))
169
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
170
+ f'{res}x{res}/Conv0/weight')
171
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
172
+ f'{res}x{res}/Conv0/bias')
173
+
174
+ # Second layer (kernel 3x3) with downsampling
175
+ layer_name = f'layer{2 * block_idx + 1}'
176
+ self.add_module(
177
+ layer_name,
178
+ ConvLayer(in_channels=in_channels,
179
+ out_channels=out_channels,
180
+ kernel_size=3,
181
+ add_bias=True,
182
+ scale_factor=2,
183
+ fused_scale=(res >= fused_scale_res
184
+ if fused_scale == 'auto'
185
+ else fused_scale),
186
+ filter_kernel=filter_kernel,
187
+ use_wscale=use_wscale,
188
+ wscale_gain=wscale_gain,
189
+ lr_mul=lr_mul,
190
+ activation_type='lrelu'))
191
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
192
+ f'{res}x{res}/Conv1_down/weight')
193
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
194
+ f'{res}x{res}/Conv1_down/bias')
195
+
196
+ # Convolution block for last resolution.
197
+ else:
198
+ self.mbstd = MiniBatchSTDLayer(groups=mbstd_groups,
199
+ new_channels=mbstd_channels,
200
+ eps=eps)
201
+
202
+ # First layer (kernel 3x3) without downsampling.
203
+ layer_name = f'layer{2 * block_idx}'
204
+ self.add_module(
205
+ layer_name,
206
+ ConvLayer(in_channels=in_channels + mbstd_channels,
207
+ out_channels=in_channels,
208
+ kernel_size=3,
209
+ add_bias=True,
210
+ scale_factor=1,
211
+ fused_scale=False,
212
+ filter_kernel=None,
213
+ use_wscale=use_wscale,
214
+ wscale_gain=wscale_gain,
215
+ lr_mul=lr_mul,
216
+ activation_type='lrelu'))
217
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
218
+ f'{res}x{res}/Conv/weight')
219
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
220
+ f'{res}x{res}/Conv/bias')
221
+
222
+ # Second layer, as a fully-connected layer.
223
+ layer_name = f'layer{2 * block_idx + 1}'
224
+ self.add_module(
225
+ f'layer{2 * block_idx + 1}',
226
+ DenseLayer(in_channels=in_channels * res * res,
227
+ out_channels=in_channels,
228
+ add_bias=True,
229
+ use_wscale=use_wscale,
230
+ wscale_gain=wscale_gain,
231
+ lr_mul=lr_mul,
232
+ activation_type='lrelu'))
233
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
234
+ f'{res}x{res}/Dense0/weight')
235
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
236
+ f'{res}x{res}/Dense0/bias')
237
+
238
+ # Final dense layer to output score.
239
+ self.output = DenseLayer(in_channels=in_channels,
240
+ out_channels=max(label_dim, 1),
241
+ add_bias=True,
242
+ use_wscale=use_wscale,
243
+ wscale_gain=1.0,
244
+ lr_mul=lr_mul,
245
+ activation_type='linear')
246
+ self.pth_to_tf_var_mapping['output.weight'] = (
247
+ f'{res}x{res}/Dense1/weight')
248
+ self.pth_to_tf_var_mapping['output.bias'] = (
249
+ f'{res}x{res}/Dense1/bias')
250
+
251
+ def get_nf(self, res):
252
+ """Gets number of feature maps according to the given resolution."""
253
+ return min(self.fmaps_base // res, self.fmaps_max)
254
+
255
+ def forward(self, image, label=None, lod=None, enable_amp=False):
256
+ expected_shape = (self.image_channels, self.resolution, self.resolution)
257
+ if image.ndim != 4 or image.shape[1:] != expected_shape:
258
+ raise ValueError(f'The input tensor should be with shape '
259
+ f'[batch_size, channel, height, width], where '
260
+ f'`channel` equals to {self.image_channels}, '
261
+ f'`height`, `width` equal to {self.resolution}!\n'
262
+ f'But `{image.shape}` is received!')
263
+
264
+ lod = self.lod.item() if lod is None else lod
265
+ if lod + self.init_res_log2 > self.final_res_log2:
266
+ raise ValueError(f'Maximum level-of-details (lod) is '
267
+ f'{self.final_res_log2 - self.init_res_log2}, '
268
+ f'but `{lod}` is received!')
269
+
270
+ if self.label_dim:
271
+ if label is None:
272
+ raise ValueError(f'Model requires an additional label '
273
+ f'(with dimension {self.label_dim}) as input, '
274
+ f'but no label is received!')
275
+ batch = image.shape[0]
276
+ if (label.ndim != 2 or label.shape != (batch, self.label_dim)):
277
+ raise ValueError(f'Input label should be with shape '
278
+ f'[batch_size, label_dim], where '
279
+ f'`batch_size` equals to {batch}, and '
280
+ f'`label_dim` equals to {self.label_dim}!\n'
281
+ f'But `{label.shape}` is received!')
282
+ label = label.to(dtype=torch.float32)
283
+
284
+ with autocast(enabled=enable_amp):
285
+ for res_log2 in range(
286
+ self.final_res_log2, self.init_res_log2 - 1, -1):
287
+ block_idx = current_lod = self.final_res_log2 - res_log2
288
+ if current_lod <= lod < current_lod + 1:
289
+ x = getattr(self, f'input{block_idx}')(image)
290
+ elif current_lod - 1 < lod < current_lod:
291
+ alpha = lod - np.floor(lod)
292
+ y = getattr(self, f'input{block_idx}')(image)
293
+ x = y * alpha + x * (1 - alpha)
294
+ if lod < current_lod + 1:
295
+ if res_log2 == self.init_res_log2:
296
+ x = self.mbstd(x)
297
+ x = getattr(self, f'layer{2 * block_idx}')(x)
298
+ x = getattr(self, f'layer{2 * block_idx + 1}')(x)
299
+ if lod > current_lod:
300
+ image = F.avg_pool2d(
301
+ image, kernel_size=2, stride=2, padding=0)
302
+ x = self.output(x)
303
+
304
+ if self.label_dim:
305
+ x = (x * label).sum(dim=1, keepdim=True)
306
+
307
+ results = {
308
+ 'score': x,
309
+ 'label': label
310
+ }
311
+ return results
312
+
313
+
314
+ class MiniBatchSTDLayer(nn.Module):
315
+ """Implements the minibatch standard deviation layer."""
316
+
317
+ def __init__(self, groups, new_channels, eps):
318
+ super().__init__()
319
+ self.groups = groups
320
+ self.new_channels = new_channels
321
+ self.eps = eps
322
+
323
+ def extra_repr(self):
324
+ return (f'groups={self.groups}, '
325
+ f'new_channels={self.new_channels}, '
326
+ f'epsilon={self.eps}')
327
+
328
+ def forward(self, x):
329
+ if self.groups <= 1 or self.new_channels < 1:
330
+ return x
331
+
332
+ N, C, H, W = x.shape
333
+ G = min(self.groups, N) # Number of groups.
334
+ nC = self.new_channels # Number of channel groups.
335
+ c = C // nC # Channels per channel group.
336
+
337
+ y = x.reshape(G, -1, nC, c, H, W) # [GnFcHW]
338
+ y = y - y.mean(dim=0) # [GnFcHW]
339
+ y = y.square().mean(dim=0) # [nFcHW]
340
+ y = (y + self.eps).sqrt() # [nFcHW]
341
+ y = y.mean(dim=(2, 3, 4)) # [nF]
342
+ y = y.reshape(-1, nC, 1, 1) # [nF11]
343
+ y = y.repeat(G, 1, H, W) # [NFHW]
344
+ x = torch.cat((x, y), dim=1) # [N(C+F)HW]
345
+
346
+ return x
347
+
348
+
349
+ class Blur(torch.autograd.Function):
350
+ """Defines blur operation with customized gradient computation."""
351
+
352
+ @staticmethod
353
+ def forward(ctx, x, kernel):
354
+ assert kernel.shape[2] == 3 and kernel.shape[3] == 3
355
+ ctx.save_for_backward(kernel)
356
+ y = F.conv2d(input=x,
357
+ weight=kernel,
358
+ bias=None,
359
+ stride=1,
360
+ padding=1,
361
+ groups=x.shape[1])
362
+ return y
363
+
364
+ @staticmethod
365
+ def backward(ctx, dy):
366
+ kernel, = ctx.saved_tensors
367
+ dx = BlurBackPropagation.apply(dy, kernel)
368
+ return dx, None, None
369
+
370
+
371
+ class BlurBackPropagation(torch.autograd.Function):
372
+ """Defines the back propagation of blur operation.
373
+
374
+ NOTE: This is used to speed up the backward of gradient penalty.
375
+ """
376
+
377
+ @staticmethod
378
+ def forward(ctx, dy, kernel):
379
+ ctx.save_for_backward(kernel)
380
+ dx = F.conv2d(input=dy,
381
+ weight=kernel.flip((2, 3)),
382
+ bias=None,
383
+ stride=1,
384
+ padding=1,
385
+ groups=dy.shape[1])
386
+ return dx
387
+
388
+ @staticmethod
389
+ def backward(ctx, ddx):
390
+ kernel, = ctx.saved_tensors
391
+ ddy = F.conv2d(input=ddx,
392
+ weight=kernel,
393
+ bias=None,
394
+ stride=1,
395
+ padding=1,
396
+ groups=ddx.shape[1])
397
+ return ddy, None, None
398
+
399
+
400
+ class ConvLayer(nn.Module):
401
+ """Implements the convolutional layer.
402
+
403
+ If downsampling is needed (i.e., `scale_factor = 2`), the feature map will
404
+ be filtered with `filter_kernel` first. If `fused_scale` is set as `True`,
405
+ `conv2d` and `downsample` will be fused as one operator, using stride
406
+ convolution.
407
+ """
408
+
409
+ def __init__(self,
410
+ in_channels,
411
+ out_channels,
412
+ kernel_size,
413
+ add_bias,
414
+ scale_factor,
415
+ fused_scale,
416
+ filter_kernel,
417
+ use_wscale,
418
+ wscale_gain,
419
+ lr_mul,
420
+ activation_type):
421
+ """Initializes with layer settings.
422
+
423
+ Args:
424
+ in_channels: Number of channels of the input tensor.
425
+ out_channels: Number of channels of the output tensor.
426
+ kernel_size: Size of the convolutional kernels.
427
+ add_bias: Whether to add bias onto the convolutional result.
428
+ scale_factor: Scale factor for downsampling. `1` means skip
429
+ downsampling.
430
+ fused_scale: Whether to fuse `conv2d` and `downsample` as one
431
+ operator, using stride convolution.
432
+ filter_kernel: Kernel used for filtering.
433
+ use_wscale: Whether to use weight scaling.
434
+ wscale_gain: Gain factor for weight scaling.
435
+ lr_mul: Learning multiplier for both weight and bias.
436
+ activation_type: Type of activation.
437
+ """
438
+ super().__init__()
439
+ self.in_channels = in_channels
440
+ self.out_channels = out_channels
441
+ self.kernel_size = kernel_size
442
+ self.add_bias = add_bias
443
+ self.scale_factor = scale_factor
444
+ self.fused_scale = fused_scale
445
+ self.filter_kernel = filter_kernel
446
+ self.use_wscale = use_wscale
447
+ self.wscale_gain = wscale_gain
448
+ self.lr_mul = lr_mul
449
+ self.activation_type = activation_type
450
+
451
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
452
+ fan_in = kernel_size * kernel_size * in_channels
453
+ wscale = wscale_gain / np.sqrt(fan_in)
454
+ if use_wscale:
455
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
456
+ self.wscale = wscale * lr_mul
457
+ else:
458
+ self.weight = nn.Parameter(
459
+ torch.randn(*weight_shape) * wscale / lr_mul)
460
+ self.wscale = lr_mul
461
+
462
+ if add_bias:
463
+ self.bias = nn.Parameter(torch.zeros(out_channels))
464
+ self.bscale = lr_mul
465
+ else:
466
+ self.bias = None
467
+
468
+ if scale_factor > 1:
469
+ assert filter_kernel is not None
470
+ kernel = np.array(filter_kernel, dtype=np.float32).reshape(1, -1)
471
+ kernel = kernel.T.dot(kernel)
472
+ kernel = kernel / np.sum(kernel)
473
+ kernel = kernel[np.newaxis, np.newaxis]
474
+ self.register_buffer('filter', torch.from_numpy(kernel))
475
+
476
+ if scale_factor > 1 and fused_scale: # use stride convolution.
477
+ self.stride = scale_factor
478
+ else:
479
+ self.stride = 1
480
+ self.padding = kernel_size // 2
481
+
482
+ assert activation_type in ['linear', 'relu', 'lrelu']
483
+
484
+ def extra_repr(self):
485
+ return (f'in_ch={self.in_channels}, '
486
+ f'out_ch={self.out_channels}, '
487
+ f'ksize={self.kernel_size}, '
488
+ f'wscale_gain={self.wscale_gain:.3f}, '
489
+ f'bias={self.add_bias}, '
490
+ f'lr_mul={self.lr_mul:.3f}, '
491
+ f'downsample={self.scale_factor}, '
492
+ f'fused_scale={self.fused_scale}, '
493
+ f'downsample_filter={self.filter_kernel}, '
494
+ f'act={self.activation_type}')
495
+
496
+ def forward(self, x):
497
+ if self.scale_factor > 1:
498
+ # Disable `autocast` for customized autograd function.
499
+ # Please check reference:
500
+ # https://pytorch.org/docs/stable/notes/amp_examples.html#autocast-and-custom-autograd-functions
501
+ with autocast(enabled=False):
502
+ f = self.filter.repeat(self.in_channels, 1, 1, 1)
503
+ x = Blur.apply(x.float(), f) # Always use FP32.
504
+
505
+ weight = self.weight
506
+ if self.wscale != 1.0:
507
+ weight = weight * self.wscale
508
+ bias = None
509
+ if self.bias is not None:
510
+ bias = self.bias
511
+ if self.bscale != 1.0:
512
+ bias = bias * self.bscale
513
+
514
+ if self.scale_factor > 1 and self.fused_scale:
515
+ weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
516
+ weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
517
+ weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
518
+ x = F.conv2d(x,
519
+ weight=weight,
520
+ bias=bias,
521
+ stride=self.stride,
522
+ padding=self.padding)
523
+ if self.scale_factor > 1 and not self.fused_scale:
524
+ down = self.scale_factor
525
+ x = F.avg_pool2d(x, kernel_size=down, stride=down, padding=0)
526
+
527
+ if self.activation_type == 'linear':
528
+ pass
529
+ elif self.activation_type == 'relu':
530
+ x = F.relu(x, inplace=True)
531
+ elif self.activation_type == 'lrelu':
532
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
533
+ else:
534
+ raise NotImplementedError(f'Not implemented activation type '
535
+ f'`{self.activation_type}`!')
536
+
537
+ return x
538
+
539
+
540
+ class DenseLayer(nn.Module):
541
+ """Implements the dense layer."""
542
+
543
+ def __init__(self,
544
+ in_channels,
545
+ out_channels,
546
+ add_bias,
547
+ use_wscale,
548
+ wscale_gain,
549
+ lr_mul,
550
+ activation_type):
551
+ """Initializes with layer settings.
552
+
553
+ Args:
554
+ in_channels: Number of channels of the input tensor.
555
+ out_channels: Number of channels of the output tensor.
556
+ add_bias: Whether to add bias onto the fully-connected result.
557
+ use_wscale: Whether to use weight scaling.
558
+ wscale_gain: Gain factor for weight scaling.
559
+ lr_mul: Learning multiplier for both weight and bias.
560
+ activation_type: Type of activation.
561
+ """
562
+ super().__init__()
563
+ self.in_channels = in_channels
564
+ self.out_channels = out_channels
565
+ self.add_bias = add_bias
566
+ self.use_wscale = use_wscale
567
+ self.wscale_gain = wscale_gain
568
+ self.lr_mul = lr_mul
569
+ self.activation_type = activation_type
570
+
571
+ weight_shape = (out_channels, in_channels)
572
+ wscale = wscale_gain / np.sqrt(in_channels)
573
+ if use_wscale:
574
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
575
+ self.wscale = wscale * lr_mul
576
+ else:
577
+ self.weight = nn.Parameter(
578
+ torch.randn(*weight_shape) * wscale / lr_mul)
579
+ self.wscale = lr_mul
580
+
581
+ if add_bias:
582
+ self.bias = nn.Parameter(torch.zeros(out_channels))
583
+ self.bscale = lr_mul
584
+ else:
585
+ self.bias = None
586
+
587
+ assert activation_type in ['linear', 'relu', 'lrelu']
588
+
589
+ def extra_repr(self):
590
+ return (f'in_ch={self.in_channels}, '
591
+ f'out_ch={self.out_channels}, '
592
+ f'wscale_gain={self.wscale_gain:.3f}, '
593
+ f'bias={self.add_bias}, '
594
+ f'lr_mul={self.lr_mul:.3f}, '
595
+ f'act={self.activation_type}')
596
+
597
+ def forward(self, x):
598
+ if x.ndim != 2:
599
+ x = x.flatten(start_dim=1)
600
+
601
+ weight = self.weight
602
+ if self.wscale != 1.0:
603
+ weight = weight * self.wscale
604
+ bias = None
605
+ if self.bias is not None:
606
+ bias = self.bias
607
+ if self.bscale != 1.0:
608
+ bias = bias * self.bscale
609
+
610
+ x = F.linear(x, weight=weight, bias=bias)
611
+
612
+ if self.activation_type == 'linear':
613
+ pass
614
+ elif self.activation_type == 'relu':
615
+ x = F.relu(x, inplace=True)
616
+ elif self.activation_type == 'lrelu':
617
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
618
+ else:
619
+ raise NotImplementedError(f'Not implemented activation type '
620
+ f'`{self.activation_type}`!')
621
+
622
+ return x
623
+
624
+ # pylint: enable=missing-function-docstring
models/stylegan_generator.py ADDED
@@ -0,0 +1,999 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the implementation of generator described in StyleGAN.
3
+
4
+ Paper: https://arxiv.org/pdf/1812.04948.pdf
5
+
6
+ Official TensorFlow implementation: https://github.com/NVlabs/stylegan
7
+ """
8
+
9
+ import numpy as np
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.cuda.amp import autocast
15
+
16
+ from .utils.ops import all_gather
17
+
18
+ __all__ = ['StyleGANGenerator']
19
+
20
+ # Resolutions allowed.
21
+ _RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
22
+
23
+ # Fused-scale options allowed.
24
+ _FUSED_SCALE_ALLOWED = [True, False, 'auto']
25
+
26
+ # pylint: disable=missing-function-docstring
27
+
28
+ class StyleGANGenerator(nn.Module):
29
+ """Defines the generator network in StyleGAN.
30
+
31
+ NOTE: The synthesized images are with `RGB` channel order and pixel range
32
+ [-1, 1].
33
+
34
+ Settings for the mapping network:
35
+
36
+ (1) z_dim: Dimension of the input latent space, Z. (default: 512)
37
+ (2) w_dim: Dimension of the output latent space, W. (default: 512)
38
+ (3) repeat_w: Repeat w-code for different layers. (default: True)
39
+ (4) normalize_z: Whether to normalize the z-code. (default: True)
40
+ (5) mapping_layers: Number of layers of the mapping network. (default: 8)
41
+ (6) mapping_fmaps: Number of hidden channels of the mapping network.
42
+ (default: 512)
43
+ (7) mapping_use_wscale: Whether to use weight scaling for the mapping
44
+ network. (default: True)
45
+ (8) mapping_wscale_gain: The factor to control weight scaling for the
46
+ mapping network (default: sqrt(2.0))
47
+ (9) mapping_lr_mul: Learning rate multiplier for the mapping network.
48
+ (default: 0.01)
49
+
50
+ Settings for conditional generation:
51
+
52
+ (1) label_dim: Dimension of the additional label for conditional generation.
53
+ In one-hot conditioning case, it is equal to the number of classes. If
54
+ set to 0, conditioning training will be disabled. (default: 0)
55
+ (2) embedding_dim: Dimension of the embedding space, if needed.
56
+ (default: 512)
57
+
58
+ Settings for the synthesis network:
59
+
60
+ (1) resolution: The resolution of the output image. (default: -1)
61
+ (2) init_res: The initial resolution to start with convolution. (default: 4)
62
+ (3) image_channels: Number of channels of the output image. (default: 3)
63
+ (4) final_tanh: Whether to use `tanh` to control the final pixel range.
64
+ (default: False)
65
+ (5) fused_scale: The strategy of fusing `upsample` and `conv2d` as one
66
+ operator. `True` means blocks from all resolutions will fuse. `False`
67
+ means blocks from all resolutions will not fuse. `auto` means blocks
68
+ from resolutions higher than (or equal to) `fused_scale_res` will fuse.
69
+ (default: `auto`)
70
+ (6) fused_scale_res: Minimum resolution to fuse `conv2d` and `downsample`
71
+ as one operator. This field only takes effect if `fused_scale` is set
72
+ as `auto`. (default: 128)
73
+ (7) use_wscale: Whether to use weight scaling. (default: True)
74
+ (8) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0))
75
+ (9) lr_mul: Learning rate multiplier for the synthesis network.
76
+ (default: 1.0)
77
+ (10) noise_type: Type of noise added to the convolutional results at each
78
+ layer. (default: `spatial`)
79
+ (11) fmaps_base: Factor to control number of feature maps for each layer.
80
+ (default: 16 << 10)
81
+ (12) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
82
+ (13) filter_kernel: Kernel used for filtering (e.g., downsampling).
83
+ (default: (1, 2, 1))
84
+ (14) eps: A small value to avoid divide overflow. (default: 1e-8)
85
+
86
+ Runtime settings:
87
+
88
+ (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for
89
+ training only. Set `None` to disable. (default: None)
90
+ (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If set
91
+ as `True`, the stats will be more accurate, yet the speed maybe a little
92
+ bit slower. (default: False)
93
+ (3) style_mixing_prob: Probability to perform style mixing as a training
94
+ regularization. Set `None` to disable. (default: None)
95
+ (4) trunc_psi: Truncation psi, set `None` to disable. (default: None)
96
+ (5) trunc_layers: Number of layers to perform truncation. (default: None)
97
+ (6) noise_mode: Mode of the layer-wise noise. Support `none`, `random`,
98
+ `const`. (default: `const`)
99
+ (7) enable_amp: Whether to enable automatic mixed precision training.
100
+ (default: False)
101
+ """
102
+
103
+ def __init__(self,
104
+ # Settings for mapping network.
105
+ z_dim=512,
106
+ w_dim=512,
107
+ repeat_w=True,
108
+ normalize_z=True,
109
+ mapping_layers=8,
110
+ mapping_fmaps=512,
111
+ mapping_use_wscale=True,
112
+ mapping_wscale_gain=np.sqrt(2.0),
113
+ mapping_lr_mul=0.01,
114
+ # Settings for conditional generation.
115
+ label_dim=0,
116
+ embedding_dim=512,
117
+ # Settings for synthesis network.
118
+ resolution=-1,
119
+ init_res=4,
120
+ image_channels=3,
121
+ final_tanh=False,
122
+ fused_scale='auto',
123
+ fused_scale_res=128,
124
+ use_wscale=True,
125
+ wscale_gain=np.sqrt(2.0),
126
+ lr_mul=1.0,
127
+ noise_type='spatial',
128
+ fmaps_base=16 << 10,
129
+ fmaps_max=512,
130
+ filter_kernel=(1, 2, 1),
131
+ eps=1e-8):
132
+ """Initializes with basic settings.
133
+
134
+ Raises:
135
+ ValueError: If the `resolution` is not supported, or `fused_scale`
136
+ is not supported.
137
+ """
138
+ super().__init__()
139
+
140
+ if resolution not in _RESOLUTIONS_ALLOWED:
141
+ raise ValueError(f'Invalid resolution: `{resolution}`!\n'
142
+ f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
143
+ if fused_scale not in _FUSED_SCALE_ALLOWED:
144
+ raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
145
+ f'Options allowed: {_FUSED_SCALE_ALLOWED}.')
146
+
147
+ self.z_dim = z_dim
148
+ self.w_dim = w_dim
149
+ self.repeat_w = repeat_w
150
+ self.normalize_z = normalize_z
151
+ self.mapping_layers = mapping_layers
152
+ self.mapping_fmaps = mapping_fmaps
153
+ self.mapping_use_wscale = mapping_use_wscale
154
+ self.mapping_wscale_gain = mapping_wscale_gain
155
+ self.mapping_lr_mul = mapping_lr_mul
156
+
157
+ self.label_dim = label_dim
158
+ self.embedding_dim = embedding_dim
159
+
160
+ self.resolution = resolution
161
+ self.init_res = init_res
162
+ self.image_channels = image_channels
163
+ self.final_tanh = final_tanh
164
+ self.fused_scale = fused_scale
165
+ self.fused_scale_res = fused_scale_res
166
+ self.use_wscale = use_wscale
167
+ self.wscale_gain = wscale_gain
168
+ self.lr_mul = lr_mul
169
+ self.noise_type = noise_type.lower()
170
+ self.fmaps_base = fmaps_base
171
+ self.fmaps_max = fmaps_max
172
+ self.filter_kernel = filter_kernel
173
+ self.eps = eps
174
+
175
+ # Dimension of latent space, which is convenient for sampling.
176
+ self.latent_dim = (z_dim,)
177
+
178
+ # Number of synthesis (convolutional) layers.
179
+ self.num_layers = int(np.log2(resolution // init_res * 2)) * 2
180
+
181
+ self.mapping = MappingNetwork(input_dim=z_dim,
182
+ output_dim=w_dim,
183
+ num_outputs=self.num_layers,
184
+ repeat_output=repeat_w,
185
+ normalize_input=normalize_z,
186
+ num_layers=mapping_layers,
187
+ hidden_dim=mapping_fmaps,
188
+ use_wscale=mapping_use_wscale,
189
+ wscale_gain=mapping_wscale_gain,
190
+ lr_mul=mapping_lr_mul,
191
+ label_dim=label_dim,
192
+ embedding_dim=embedding_dim,
193
+ eps=eps)
194
+
195
+ # This is used for truncation trick.
196
+ if self.repeat_w:
197
+ self.register_buffer('w_avg', torch.zeros(w_dim))
198
+ else:
199
+ self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim))
200
+
201
+ self.synthesis = SynthesisNetwork(resolution=resolution,
202
+ init_res=init_res,
203
+ w_dim=w_dim,
204
+ image_channels=image_channels,
205
+ final_tanh=final_tanh,
206
+ fused_scale=fused_scale,
207
+ fused_scale_res=fused_scale_res,
208
+ use_wscale=use_wscale,
209
+ wscale_gain=wscale_gain,
210
+ lr_mul=lr_mul,
211
+ noise_type=noise_type,
212
+ fmaps_base=fmaps_base,
213
+ fmaps_max=fmaps_max,
214
+ filter_kernel=filter_kernel,
215
+ eps=eps)
216
+
217
+ self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
218
+ for key, val in self.mapping.pth_to_tf_var_mapping.items():
219
+ self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
220
+ for key, val in self.synthesis.pth_to_tf_var_mapping.items():
221
+ self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val
222
+
223
+ def set_space_of_latent(self, space_of_latent):
224
+ """Sets the space to which the latent code belong.
225
+
226
+ See `SynthesisNetwork` for more details.
227
+ """
228
+ self.synthesis.set_space_of_latent(space_of_latent)
229
+
230
+ def forward(self,
231
+ z,
232
+ label=None,
233
+ lod=None,
234
+ w_moving_decay=None,
235
+ sync_w_avg=False,
236
+ style_mixing_prob=None,
237
+ trunc_psi=None,
238
+ trunc_layers=None,
239
+ noise_mode='const',
240
+ enable_amp=False):
241
+ mapping_results = self.mapping(z, label)
242
+
243
+ w = mapping_results['w']
244
+ if self.training and w_moving_decay is not None:
245
+ if sync_w_avg:
246
+ batch_w_avg = all_gather(w.detach()).mean(dim=0)
247
+ else:
248
+ batch_w_avg = w.detach().mean(dim=0)
249
+ self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay))
250
+
251
+ wp = mapping_results.pop('wp')
252
+ if self.training and style_mixing_prob is not None:
253
+ if np.random.uniform() < style_mixing_prob:
254
+ new_z = torch.randn_like(z)
255
+ new_wp = self.mapping(new_z, label)['wp']
256
+ lod = self.synthesis.lod.item() if lod is None else lod
257
+ current_layers = self.num_layers - int(lod) * 2
258
+ mixing_cutoff = np.random.randint(1, current_layers)
259
+ wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:]
260
+
261
+ if not self.training:
262
+ trunc_psi = 1.0 if trunc_psi is None else trunc_psi
263
+ trunc_layers = 0 if trunc_layers is None else trunc_layers
264
+ if trunc_psi < 1.0 and trunc_layers > 0:
265
+ w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers]
266
+ wp[:, :trunc_layers] = w_avg.lerp(
267
+ wp[:, :trunc_layers], trunc_psi)
268
+
269
+ with autocast(enabled=enable_amp):
270
+ synthesis_results = self.synthesis(wp,
271
+ lod=lod,
272
+ noise_mode=noise_mode)
273
+
274
+ return {**mapping_results, **synthesis_results}
275
+
276
+
277
+ class MappingNetwork(nn.Module):
278
+ """Implements the latent space mapping module.
279
+
280
+ Basically, this module executes several dense layers in sequence, and the
281
+ label embedding if needed.
282
+ """
283
+
284
+ def __init__(self,
285
+ input_dim,
286
+ output_dim,
287
+ num_outputs,
288
+ repeat_output,
289
+ normalize_input,
290
+ num_layers,
291
+ hidden_dim,
292
+ use_wscale,
293
+ wscale_gain,
294
+ lr_mul,
295
+ label_dim,
296
+ embedding_dim,
297
+ eps):
298
+ super().__init__()
299
+
300
+ self.input_dim = input_dim
301
+ self.output_dim = output_dim
302
+ self.num_outputs = num_outputs
303
+ self.repeat_output = repeat_output
304
+ self.normalize_input = normalize_input
305
+ self.num_layers = num_layers
306
+ self.hidden_dim = hidden_dim
307
+ self.use_wscale = use_wscale
308
+ self.wscale_gain = wscale_gain
309
+ self.lr_mul = lr_mul
310
+ self.label_dim = label_dim
311
+ self.embedding_dim = embedding_dim
312
+ self.eps = eps
313
+
314
+ self.pth_to_tf_var_mapping = {}
315
+
316
+ if normalize_input:
317
+ self.norm = PixelNormLayer(dim=1, eps=eps)
318
+
319
+ if self.label_dim > 0:
320
+ input_dim = input_dim + embedding_dim
321
+ self.embedding = nn.Parameter(
322
+ torch.randn(label_dim, embedding_dim))
323
+ self.pth_to_tf_var_mapping['embedding'] = 'LabelConcat/weight'
324
+
325
+ if num_outputs is not None and not repeat_output:
326
+ output_dim = output_dim * num_outputs
327
+ for i in range(num_layers):
328
+ in_channels = (input_dim if i == 0 else hidden_dim)
329
+ out_channels = (output_dim if i == (num_layers - 1) else hidden_dim)
330
+ self.add_module(f'dense{i}',
331
+ DenseLayer(in_channels=in_channels,
332
+ out_channels=out_channels,
333
+ add_bias=True,
334
+ use_wscale=use_wscale,
335
+ wscale_gain=wscale_gain,
336
+ lr_mul=lr_mul,
337
+ activation_type='lrelu'))
338
+ self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
339
+ self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
340
+
341
+ def forward(self, z, label=None):
342
+ if z.ndim != 2 or z.shape[1] != self.input_dim:
343
+ raise ValueError(f'Input latent code should be with shape '
344
+ f'[batch_size, input_dim], where '
345
+ f'`input_dim` equals to {self.input_dim}!\n'
346
+ f'But `{z.shape}` is received!')
347
+
348
+ if self.label_dim > 0:
349
+ if label is None:
350
+ raise ValueError(f'Model requires an additional label '
351
+ f'(with dimension {self.label_dim}) as input, '
352
+ f'but no label is received!')
353
+ if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim):
354
+ raise ValueError(f'Input label should be with shape '
355
+ f'[batch_size, label_dim], where '
356
+ f'`batch_size` equals to that of '
357
+ f'latent codes ({z.shape[0]}) and '
358
+ f'`label_dim` equals to {self.label_dim}!\n'
359
+ f'But `{label.shape}` is received!')
360
+ label = label.to(dtype=torch.float32)
361
+ embedding = torch.matmul(label, self.embedding)
362
+ z = torch.cat((z, embedding), dim=1)
363
+
364
+ if self.normalize_input:
365
+ w = self.norm(z)
366
+ else:
367
+ w = z
368
+
369
+ for i in range(self.num_layers):
370
+ w = getattr(self, f'dense{i}')(w)
371
+
372
+ wp = None
373
+ if self.num_outputs is not None:
374
+ if self.repeat_output:
375
+ wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1))
376
+ else:
377
+ wp = w.reshape(-1, self.num_outputs, self.output_dim)
378
+
379
+ results = {
380
+ 'z': z,
381
+ 'label': label,
382
+ 'w': w,
383
+ 'wp': wp,
384
+ }
385
+ if self.label_dim > 0:
386
+ results['embedding'] = embedding
387
+ return results
388
+
389
+
390
+ class SynthesisNetwork(nn.Module):
391
+ """Implements the image synthesis module.
392
+
393
+ Basically, this module executes several convolutional layers in sequence.
394
+ """
395
+
396
+ def __init__(self,
397
+ resolution,
398
+ init_res,
399
+ w_dim,
400
+ image_channels,
401
+ final_tanh,
402
+ fused_scale,
403
+ fused_scale_res,
404
+ use_wscale,
405
+ wscale_gain,
406
+ lr_mul,
407
+ noise_type,
408
+ fmaps_base,
409
+ fmaps_max,
410
+ filter_kernel,
411
+ eps):
412
+ super().__init__()
413
+
414
+ self.init_res = init_res
415
+ self.init_res_log2 = int(np.log2(init_res))
416
+ self.resolution = resolution
417
+ self.final_res_log2 = int(np.log2(resolution))
418
+ self.w_dim = w_dim
419
+ self.image_channels = image_channels
420
+ self.final_tanh = final_tanh
421
+ self.fused_scale = fused_scale
422
+ self.fused_scale_res = fused_scale_res
423
+ self.use_wscale = use_wscale
424
+ self.wscale_gain = wscale_gain
425
+ self.lr_mul = lr_mul
426
+ self.noise_type = noise_type.lower()
427
+ self.fmaps_base = fmaps_base
428
+ self.fmaps_max = fmaps_max
429
+ self.eps = eps
430
+
431
+ self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
432
+
433
+ # Level-of-details (used for progressive training).
434
+ self.register_buffer('lod', torch.zeros(()))
435
+ self.pth_to_tf_var_mapping = {'lod': 'lod'}
436
+
437
+ for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
438
+ res = 2 ** res_log2
439
+ in_channels = self.get_nf(res // 2)
440
+ out_channels = self.get_nf(res)
441
+ block_idx = res_log2 - self.init_res_log2
442
+
443
+ # First layer (kernel 3x3) with upsampling
444
+ layer_name = f'layer{2 * block_idx}'
445
+ if res == self.init_res:
446
+ self.add_module(layer_name,
447
+ ModulateConvLayer(in_channels=0,
448
+ out_channels=out_channels,
449
+ resolution=res,
450
+ w_dim=w_dim,
451
+ kernel_size=None,
452
+ add_bias=True,
453
+ scale_factor=None,
454
+ fused_scale=None,
455
+ filter_kernel=None,
456
+ use_wscale=use_wscale,
457
+ wscale_gain=wscale_gain,
458
+ lr_mul=lr_mul,
459
+ noise_type=noise_type,
460
+ activation_type='lrelu',
461
+ use_style=True,
462
+ eps=eps))
463
+ tf_layer_name = 'Const'
464
+ self.pth_to_tf_var_mapping[f'{layer_name}.const'] = (
465
+ f'{res}x{res}/{tf_layer_name}/const')
466
+ else:
467
+ self.add_module(
468
+ layer_name,
469
+ ModulateConvLayer(in_channels=in_channels,
470
+ out_channels=out_channels,
471
+ resolution=res,
472
+ w_dim=w_dim,
473
+ kernel_size=3,
474
+ add_bias=True,
475
+ scale_factor=2,
476
+ fused_scale=(res >= fused_scale_res
477
+ if fused_scale == 'auto'
478
+ else fused_scale),
479
+ filter_kernel=filter_kernel,
480
+ use_wscale=use_wscale,
481
+ wscale_gain=wscale_gain,
482
+ lr_mul=lr_mul,
483
+ noise_type=noise_type,
484
+ activation_type='lrelu',
485
+ use_style=True,
486
+ eps=eps))
487
+ tf_layer_name = 'Conv0_up'
488
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
489
+ f'{res}x{res}/{tf_layer_name}/weight')
490
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
491
+ f'{res}x{res}/{tf_layer_name}/bias')
492
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
493
+ f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
494
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
495
+ f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
496
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
497
+ f'{res}x{res}/{tf_layer_name}/Noise/weight')
498
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
499
+ f'noise{2 * block_idx}')
500
+
501
+ # Second layer (kernel 3x3) without upsampling.
502
+ layer_name = f'layer{2 * block_idx + 1}'
503
+ self.add_module(layer_name,
504
+ ModulateConvLayer(in_channels=out_channels,
505
+ out_channels=out_channels,
506
+ resolution=res,
507
+ w_dim=w_dim,
508
+ kernel_size=3,
509
+ add_bias=True,
510
+ scale_factor=1,
511
+ fused_scale=False,
512
+ filter_kernel=None,
513
+ use_wscale=use_wscale,
514
+ wscale_gain=wscale_gain,
515
+ lr_mul=lr_mul,
516
+ noise_type=noise_type,
517
+ activation_type='lrelu',
518
+ use_style=True,
519
+ eps=eps))
520
+ tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
521
+ self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
522
+ f'{res}x{res}/{tf_layer_name}/weight')
523
+ self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
524
+ f'{res}x{res}/{tf_layer_name}/bias')
525
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
526
+ f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
527
+ self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
528
+ f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
529
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
530
+ f'{res}x{res}/{tf_layer_name}/Noise/weight')
531
+ self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
532
+ f'noise{2 * block_idx + 1}')
533
+
534
+ # Output convolution layer for each resolution.
535
+ self.add_module(f'output{block_idx}',
536
+ ModulateConvLayer(in_channels=out_channels,
537
+ out_channels=image_channels,
538
+ resolution=res,
539
+ w_dim=w_dim,
540
+ kernel_size=1,
541
+ add_bias=True,
542
+ scale_factor=1,
543
+ fused_scale=False,
544
+ filter_kernel=None,
545
+ use_wscale=use_wscale,
546
+ wscale_gain=1.0,
547
+ lr_mul=lr_mul,
548
+ noise_type='none',
549
+ activation_type='linear',
550
+ use_style=False,
551
+ eps=eps))
552
+ self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
553
+ f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
554
+ self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
555
+ f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')
556
+
557
+ def get_nf(self, res):
558
+ """Gets number of feature maps according to the given resolution."""
559
+ return min(self.fmaps_base // res, self.fmaps_max)
560
+
561
+ def set_space_of_latent(self, space_of_latent):
562
+ """Sets the space to which the latent code belong.
563
+
564
+ This function is particularly used for choosing how to inject the latent
565
+ code into the convolutional layers. The original generator will take a
566
+ W-Space code and apply it for style modulation after an affine
567
+ transformation. But, sometimes, it may need to directly feed an already
568
+ affine-transformed code into the convolutional layer, e.g., when
569
+ training an encoder for GAN inversion. We term the transformed space as
570
+ Style Space (or Y-Space). This function is designed to tell the
571
+ convolutional layers how to use the input code.
572
+
573
+ Args:
574
+ space_of_latent: The space to which the latent code belong. Case
575
+ insensitive. Support `W` and `Y`.
576
+ """
577
+ space_of_latent = space_of_latent.upper()
578
+ for module in self.modules():
579
+ if isinstance(module, ModulateConvLayer) and module.use_style:
580
+ setattr(module, 'space_of_latent', space_of_latent)
581
+
582
+ def forward(self, wp, lod=None, noise_mode='const'):
583
+ lod = self.lod.item() if lod is None else lod
584
+ if lod + self.init_res_log2 > self.final_res_log2:
585
+ raise ValueError(f'Maximum level-of-details (lod) is '
586
+ f'{self.final_res_log2 - self.init_res_log2}, '
587
+ f'but `{lod}` is received!')
588
+
589
+ results = {'wp': wp}
590
+ x = None
591
+ for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
592
+ current_lod = self.final_res_log2 - res_log2
593
+ block_idx = res_log2 - self.init_res_log2
594
+ if lod < current_lod + 1:
595
+ layer = getattr(self, f'layer{2 * block_idx}')
596
+ x, style = layer(x, wp[:, 2 * block_idx], noise_mode)
597
+ results[f'style{2 * block_idx}'] = style
598
+ layer = getattr(self, f'layer{2 * block_idx + 1}')
599
+ x, style = layer(x, wp[:, 2 * block_idx + 1], noise_mode)
600
+ results[f'style{2 * block_idx + 1}'] = style
601
+ if current_lod - 1 < lod <= current_lod:
602
+ image = getattr(self, f'output{block_idx}')(x)
603
+ elif current_lod < lod < current_lod + 1:
604
+ alpha = np.ceil(lod) - lod
605
+ temp = getattr(self, f'output{block_idx}')(x)
606
+ image = F.interpolate(image, scale_factor=2, mode='nearest')
607
+ image = temp * alpha + image * (1 - alpha)
608
+ elif lod >= current_lod + 1:
609
+ image = F.interpolate(image, scale_factor=2, mode='nearest')
610
+
611
+ if self.final_tanh:
612
+ image = torch.tanh(image)
613
+ results['image'] = image
614
+ return results
615
+
616
+
617
+ class PixelNormLayer(nn.Module):
618
+ """Implements pixel-wise feature vector normalization layer."""
619
+
620
+ def __init__(self, dim, eps):
621
+ super().__init__()
622
+ self.dim = dim
623
+ self.eps = eps
624
+
625
+ def extra_repr(self):
626
+ return f'dim={self.dim}, epsilon={self.eps}'
627
+
628
+ def forward(self, x):
629
+ scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
630
+ return x * scale
631
+
632
+
633
+ class Blur(torch.autograd.Function):
634
+ """Defines blur operation with customized gradient computation."""
635
+
636
+ @staticmethod
637
+ def forward(ctx, x, kernel):
638
+ assert kernel.shape[2] == 3 and kernel.shape[3] == 3
639
+ ctx.save_for_backward(kernel)
640
+ y = F.conv2d(input=x,
641
+ weight=kernel,
642
+ bias=None,
643
+ stride=1,
644
+ padding=1,
645
+ groups=x.shape[1])
646
+ return y
647
+
648
+ @staticmethod
649
+ def backward(ctx, dy):
650
+ kernel, = ctx.saved_tensors
651
+ dx = F.conv2d(input=dy,
652
+ weight=kernel.flip((2, 3)),
653
+ bias=None,
654
+ stride=1,
655
+ padding=1,
656
+ groups=dy.shape[1])
657
+ return dx, None, None
658
+
659
+
660
+ class ModulateConvLayer(nn.Module):
661
+ """Implements the convolutional layer with style modulation."""
662
+
663
+ def __init__(self,
664
+ in_channels,
665
+ out_channels,
666
+ resolution,
667
+ w_dim,
668
+ kernel_size,
669
+ add_bias,
670
+ scale_factor,
671
+ fused_scale,
672
+ filter_kernel,
673
+ use_wscale,
674
+ wscale_gain,
675
+ lr_mul,
676
+ noise_type,
677
+ activation_type,
678
+ use_style,
679
+ eps):
680
+ """Initializes with layer settings.
681
+
682
+ Args:
683
+ in_channels: Number of channels of the input tensor.
684
+ out_channels: Number of channels of the output tensor.
685
+ resolution: Resolution of the output tensor.
686
+ w_dim: Dimension of W space for style modulation.
687
+ kernel_size: Size of the convolutional kernels.
688
+ add_bias: Whether to add bias onto the convolutional result.
689
+ scale_factor: Scale factor for upsampling.
690
+ fused_scale: Whether to fuse `upsample` and `conv2d` as one
691
+ operator, using transpose convolution.
692
+ filter_kernel: Kernel used for filtering.
693
+ use_wscale: Whether to use weight scaling.
694
+ wscale_gain: Gain factor for weight scaling.
695
+ lr_mul: Learning multiplier for both weight and bias.
696
+ noise_type: Type of noise added to the feature map after the
697
+ convolution (if needed). Support `none`, `spatial` and
698
+ `channel`.
699
+ activation_type: Type of activation.
700
+ use_style: Whether to apply style modulation.
701
+ eps: A small value to avoid divide overflow.
702
+ """
703
+ super().__init__()
704
+
705
+ self.in_channels = in_channels
706
+ self.out_channels = out_channels
707
+ self.resolution = resolution
708
+ self.w_dim = w_dim
709
+ self.kernel_size = kernel_size
710
+ self.add_bias = add_bias
711
+ self.scale_factor = scale_factor
712
+ self.fused_scale = fused_scale
713
+ self.filter_kernel = filter_kernel
714
+ self.use_wscale = use_wscale
715
+ self.wscale_gain = wscale_gain
716
+ self.lr_mul = lr_mul
717
+ self.noise_type = noise_type.lower()
718
+ self.activation_type = activation_type
719
+ self.use_style = use_style
720
+ self.eps = eps
721
+
722
+ # Set up noise.
723
+ if self.noise_type == 'none':
724
+ pass
725
+ elif self.noise_type == 'spatial':
726
+ self.register_buffer(
727
+ 'noise', torch.randn(1, 1, resolution, resolution))
728
+ self.noise_strength = nn.Parameter(
729
+ torch.zeros(1, out_channels, 1, 1))
730
+ elif self.noise_type == 'channel':
731
+ self.register_buffer(
732
+ 'noise', torch.randn(1, out_channels, 1, 1))
733
+ self.noise_strength = nn.Parameter(
734
+ torch.zeros(1, 1, resolution, resolution))
735
+ else:
736
+ raise NotImplementedError(f'Not implemented noise type: '
737
+ f'`{noise_type}`!')
738
+
739
+ # Set up bias.
740
+ if add_bias:
741
+ self.bias = nn.Parameter(torch.zeros(out_channels))
742
+ self.bscale = lr_mul
743
+ else:
744
+ self.bias = None
745
+
746
+ # Set up activation.
747
+ assert activation_type in ['linear', 'relu', 'lrelu']
748
+
749
+ # Set up style.
750
+ if use_style:
751
+ self.space_of_latent = 'W'
752
+ self.style = DenseLayer(in_channels=w_dim,
753
+ out_channels=out_channels * 2,
754
+ add_bias=True,
755
+ use_wscale=use_wscale,
756
+ wscale_gain=1.0,
757
+ lr_mul=1.0,
758
+ activation_type='linear')
759
+
760
+ if in_channels == 0: # First layer.
761
+ self.const = nn.Parameter(
762
+ torch.ones(1, out_channels, resolution, resolution))
763
+ return
764
+
765
+ # Set up weight.
766
+ weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
767
+ fan_in = kernel_size * kernel_size * in_channels
768
+ wscale = wscale_gain / np.sqrt(fan_in)
769
+ if use_wscale:
770
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
771
+ self.wscale = wscale * lr_mul
772
+ else:
773
+ self.weight = nn.Parameter(
774
+ torch.randn(*weight_shape) * wscale / lr_mul)
775
+ self.wscale = lr_mul
776
+
777
+ # Set up upsampling filter (if needed).
778
+ if scale_factor > 1:
779
+ assert filter_kernel is not None
780
+ kernel = np.array(filter_kernel, dtype=np.float32).reshape(1, -1)
781
+ kernel = kernel.T.dot(kernel)
782
+ kernel = kernel / np.sum(kernel)
783
+ kernel = kernel[np.newaxis, np.newaxis]
784
+ self.register_buffer('filter', torch.from_numpy(kernel))
785
+
786
+ if scale_factor > 1 and fused_scale: # use transpose convolution.
787
+ self.stride = scale_factor
788
+ else:
789
+ self.stride = 1
790
+ self.padding = kernel_size // 2
791
+
792
+ def extra_repr(self):
793
+ return (f'in_ch={self.in_channels}, '
794
+ f'out_ch={self.out_channels}, '
795
+ f'ksize={self.kernel_size}, '
796
+ f'wscale_gain={self.wscale_gain:.3f}, '
797
+ f'bias={self.add_bias}, '
798
+ f'lr_mul={self.lr_mul:.3f}, '
799
+ f'upsample={self.scale_factor}, '
800
+ f'fused_scale={self.fused_scale}, '
801
+ f'upsample_filter={self.filter_kernel}, '
802
+ f'noise_type={self.noise_type}, '
803
+ f'act={self.activation_type}, '
804
+ f'use_style={self.use_style}')
805
+
806
+ def forward_style(self, w):
807
+ """Gets style code from the given input.
808
+
809
+ More specifically, if the input is from W-Space, it will be projected by
810
+ an affine transformation. If it is from the Style Space (Y-Space), no
811
+ operation is required.
812
+
813
+ NOTE: For codes from Y-Space, we use slicing to make sure the dimension
814
+ is correct, in case that the code is padded before fed into this layer.
815
+ """
816
+ space_of_latent = self.space_of_latent.upper()
817
+ if space_of_latent == 'W':
818
+ if w.ndim != 2 or w.shape[1] != self.w_dim:
819
+ raise ValueError(f'The input tensor should be with shape '
820
+ f'[batch_size, w_dim], where '
821
+ f'`w_dim` equals to {self.w_dim}!\n'
822
+ f'But `{w.shape}` is received!')
823
+ style = self.style(w)
824
+ elif space_of_latent == 'Y':
825
+ if w.ndim != 2 or w.shape[1] < self.out_channels * 2:
826
+ raise ValueError(f'The input tensor should be with shape '
827
+ f'[batch_size, y_dim], where '
828
+ f'`y_dim` equals to {self.out_channels * 2}!\n'
829
+ f'But `{w.shape}` is received!')
830
+ style = w[:, :self.out_channels * 2]
831
+ else:
832
+ raise NotImplementedError(f'Not implemented `space_of_latent`: '
833
+ f'`{space_of_latent}`!')
834
+ return style
835
+
836
+ def forward(self, x, w=None, noise_mode='const'):
837
+ if self.in_channels == 0:
838
+ assert x is None
839
+ x = self.const.repeat(w.shape[0], 1, 1, 1)
840
+ else:
841
+ weight = self.weight
842
+ if self.wscale != 1.0:
843
+ weight = weight * self.wscale
844
+
845
+ if self.scale_factor > 1 and self.fused_scale:
846
+ weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0)
847
+ weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
848
+ weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
849
+ x = F.conv_transpose2d(x,
850
+ weight=weight.transpose(0, 1),
851
+ bias=None,
852
+ stride=self.stride,
853
+ padding=self.padding)
854
+ else:
855
+ if self.scale_factor > 1:
856
+ up = self.scale_factor
857
+ x = F.interpolate(x, scale_factor=up, mode='nearest')
858
+ x = F.conv2d(x,
859
+ weight=weight,
860
+ bias=None,
861
+ stride=self.stride,
862
+ padding=self.padding)
863
+
864
+ if self.scale_factor > 1:
865
+ # Disable `autocast` for customized autograd function.
866
+ # Please check reference:
867
+ # https://pytorch.org/docs/stable/notes/amp_examples.html#autocast-and-custom-autograd-functions
868
+ with autocast(enabled=False):
869
+ f = self.filter.repeat(self.out_channels, 1, 1, 1)
870
+ x = Blur.apply(x.float(), f) # Always use FP32.
871
+
872
+ # Prepare noise.
873
+ noise_mode = noise_mode.lower()
874
+ if self.noise_type != 'none' and noise_mode != 'none':
875
+ if noise_mode == 'random':
876
+ noise = torch.randn(
877
+ (x.shape[0], *self.noise.shape[1:]), device=x.device)
878
+ elif noise_mode == 'const':
879
+ noise = self.noise
880
+ else:
881
+ raise ValueError(f'Unknown noise mode `{noise_mode}`!')
882
+ x = x + noise * self.noise_strength
883
+
884
+ if self.bias is not None:
885
+ bias = self.bias
886
+ if self.bscale != 1.0:
887
+ bias = bias * self.bscale
888
+ x = x + bias.reshape(1, self.out_channels, 1, 1)
889
+
890
+ if self.activation_type == 'linear':
891
+ pass
892
+ elif self.activation_type == 'relu':
893
+ x = F.relu(x, inplace=True)
894
+ elif self.activation_type == 'lrelu':
895
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
896
+ else:
897
+ raise NotImplementedError(f'Not implemented activation type '
898
+ f'`{self.activation_type}`!')
899
+
900
+ if not self.use_style:
901
+ return x
902
+
903
+ # Instance normalization.
904
+ x = x - x.mean(dim=(2, 3), keepdim=True)
905
+ scale = (x.square().mean(dim=(2, 3), keepdim=True) + self.eps).rsqrt()
906
+ x = x * scale
907
+ # Style modulation.
908
+ style = self.forward_style(w)
909
+ style_split = style.unsqueeze(2).unsqueeze(3).chunk(2, dim=1)
910
+ x = x * (style_split[0] + 1) + style_split[1]
911
+
912
+ return x, style
913
+
914
+
915
+ class DenseLayer(nn.Module):
916
+ """Implements the dense layer."""
917
+
918
+ def __init__(self,
919
+ in_channels,
920
+ out_channels,
921
+ add_bias,
922
+ use_wscale,
923
+ wscale_gain,
924
+ lr_mul,
925
+ activation_type):
926
+ """Initializes with layer settings.
927
+
928
+ Args:
929
+ in_channels: Number of channels of the input tensor.
930
+ out_channels: Number of channels of the output tensor.
931
+ add_bias: Whether to add bias onto the fully-connected result.
932
+ use_wscale: Whether to use weight scaling.
933
+ wscale_gain: Gain factor for weight scaling.
934
+ lr_mul: Learning multiplier for both weight and bias.
935
+ activation_type: Type of activation.
936
+ """
937
+ super().__init__()
938
+ self.in_channels = in_channels
939
+ self.out_channels = out_channels
940
+ self.add_bias = add_bias
941
+ self.use_wscale = use_wscale
942
+ self.wscale_gain = wscale_gain
943
+ self.lr_mul = lr_mul
944
+ self.activation_type = activation_type
945
+
946
+ weight_shape = (out_channels, in_channels)
947
+ wscale = wscale_gain / np.sqrt(in_channels)
948
+ if use_wscale:
949
+ self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
950
+ self.wscale = wscale * lr_mul
951
+ else:
952
+ self.weight = nn.Parameter(
953
+ torch.randn(*weight_shape) * wscale / lr_mul)
954
+ self.wscale = lr_mul
955
+
956
+ if add_bias:
957
+ self.bias = nn.Parameter(torch.zeros(out_channels))
958
+ self.bscale = lr_mul
959
+ else:
960
+ self.bias = None
961
+
962
+ assert activation_type in ['linear', 'relu', 'lrelu']
963
+
964
+ def extra_repr(self):
965
+ return (f'in_ch={self.in_channels}, '
966
+ f'out_ch={self.out_channels}, '
967
+ f'wscale_gain={self.wscale_gain:.3f}, '
968
+ f'bias={self.add_bias}, '
969
+ f'lr_mul={self.lr_mul:.3f}, '
970
+ f'act={self.activation_type}')
971
+
972
+ def forward(self, x):
973
+ if x.ndim != 2:
974
+ x = x.flatten(start_dim=1)
975
+
976
+ weight = self.weight
977
+ if self.wscale != 1.0:
978
+ weight = weight * self.wscale
979
+ bias = None
980
+ if self.bias is not None:
981
+ bias = self.bias
982
+ if self.bscale != 1.0:
983
+ bias = bias * self.bscale
984
+
985
+ x = F.linear(x, weight=weight, bias=bias)
986
+
987
+ if self.activation_type == 'linear':
988
+ pass
989
+ elif self.activation_type == 'relu':
990
+ x = F.relu(x, inplace=True)
991
+ elif self.activation_type == 'lrelu':
992
+ x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
993
+ else:
994
+ raise NotImplementedError(f'Not implemented activation type '
995
+ f'`{self.activation_type}`!')
996
+
997
+ return x
998
+
999
+ # pylint: enable=missing-function-docstring
models/test.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Unit test for loading pre-trained models.
3
+
4
+ Basically, this file tests whether the perceptual model (VGG16) and the
5
+ inception model (InceptionV3), which are commonly used for loss computation and
6
+ evaluation, have the expected behavior after loading pre-trained weights. In
7
+ particular, we compare with the models from repo
8
+
9
+ https://github.com/NVlabs/stylegan2-ada-pytorch
10
+ """
11
+
12
+ import torch
13
+
14
+ from models import build_model
15
+ from utils.misc import download_url
16
+
17
+ __all__ = ['test_model']
18
+
19
+ _BATCH_SIZE = 4
20
+ # pylint: disable=line-too-long
21
+ _PERCEPTUAL_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
22
+ _INCEPTION_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23
+ # pylint: enable=line-too-long
24
+
25
+
26
+ def test_model():
27
+ """Collects all model tests."""
28
+ torch.backends.cudnn.enabled = True
29
+ torch.backends.cudnn.allow_tf32 = False
30
+ torch.backends.cuda.matmul.allow_tf32 = False
31
+ torch.backends.cudnn.benchmark = False
32
+ torch.backends.cudnn.deterministic = True
33
+ print('========== Start Model Test ==========')
34
+ test_perceptual()
35
+ test_inception()
36
+ print('========== Finish Model Test ==========')
37
+
38
+
39
+ def test_perceptual():
40
+ """Test the perceptual model."""
41
+ print('===== Testing Perceptual Model =====')
42
+
43
+ print('Build test model.')
44
+ model = build_model('PerceptualModel',
45
+ use_torchvision=False,
46
+ no_top=False,
47
+ enable_lpips=True)
48
+
49
+ print('Build reference model.')
50
+ ref_model_path, _, = download_url(_PERCEPTUAL_URL)
51
+ with open(ref_model_path, 'rb') as f:
52
+ ref_model = torch.jit.load(f).eval().cuda()
53
+
54
+ print('Test performance: ')
55
+ for size in [224, 128, 256, 512, 1024]:
56
+ raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size))
57
+ raw_img_comp = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size))
58
+
59
+ # The test model requires input images to have range [-1, 1].
60
+ img = raw_img.to(torch.float32).cuda() / 127.5 - 1
61
+ img_comp = raw_img_comp.to(torch.float32).cuda() / 127.5 - 1
62
+ feat = model(img, resize_input=True, return_tensor='feature')
63
+ pred = model(img, resize_input=True, return_tensor='prediction')
64
+ lpips = model(img, img_comp, resize_input=False, return_tensor='lpips')
65
+ assert feat.shape == (_BATCH_SIZE, 4096)
66
+ assert pred.shape == (_BATCH_SIZE, 1000)
67
+ assert lpips.shape == (_BATCH_SIZE,)
68
+
69
+ # The reference model requires input images to have range [0, 255].
70
+ img = raw_img.to(torch.float32).cuda()
71
+ img_comp = raw_img_comp.to(torch.float32).cuda()
72
+ ref_feat = ref_model(img, resize_images=True, return_features=True)
73
+ ref_pred = ref_model(img, resize_images=True, return_features=False)
74
+ temp = ref_model(torch.cat([img, img_comp], dim=0),
75
+ resize_images=False, return_lpips=True).chunk(2)
76
+ ref_lpips = (temp[0] - temp[1]).square().sum(dim=1, keepdim=False)
77
+ assert ref_feat.shape == (_BATCH_SIZE, 4096)
78
+ assert ref_pred.shape == (_BATCH_SIZE, 1000)
79
+ assert ref_lpips.shape == (_BATCH_SIZE,)
80
+
81
+ print(f' Size {size}x{size}, feature (with resize):\n '
82
+ f'mean: {(feat - ref_feat).abs().mean().item():.3e}, '
83
+ f'max: {(feat - ref_feat).abs().max().item():.3e}, '
84
+ f'ref_mean: {ref_feat.abs().mean().item():.3e}, '
85
+ f'ref_max: {ref_feat.abs().max().item():.3e}.')
86
+ print(f' Size {size}x{size}, prediction (with resize):\n '
87
+ f'mean: {(pred - ref_pred).abs().mean().item():.3e}, '
88
+ f'max: {(pred - ref_pred).abs().max().item():.3e}, '
89
+ f'ref_mean: {ref_pred.abs().mean().item():.3e}, '
90
+ f'ref_max: {ref_pred.abs().max().item():.3e}.')
91
+ print(f' Size {size}x{size}, LPIPS (without resize):\n '
92
+ f'mean: {(lpips - ref_lpips).abs().mean().item():.3e}, '
93
+ f'max: {(lpips - ref_lpips).abs().max().item():.3e}, '
94
+ f'ref_mean: {ref_lpips.abs().mean().item():.3e}, '
95
+ f'ref_max: {ref_lpips.abs().max().item():.3e}.')
96
+
97
+
98
+ def test_inception():
99
+ """Test the inception model."""
100
+ print('===== Testing Inception Model =====')
101
+
102
+ print('Build test model.')
103
+ model = build_model('InceptionModel', align_tf=True)
104
+
105
+ print('Build reference model.')
106
+ ref_model_path, _, = download_url(_INCEPTION_URL)
107
+ with open(ref_model_path, 'rb') as f:
108
+ ref_model = torch.jit.load(f).eval().cuda()
109
+
110
+ print('Test performance: ')
111
+ for size in [299, 128, 256, 512, 1024]:
112
+ raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size))
113
+
114
+ # The test model requires input images to have range [-1, 1].
115
+ img = raw_img.to(torch.float32).cuda() / 127.5 - 1
116
+ feat = model(img)
117
+ pred = model(img, output_predictions=True)
118
+ pred_nb = model(img, output_predictions=True, remove_logits_bias=True)
119
+ assert feat.shape == (_BATCH_SIZE, 2048)
120
+ assert pred.shape == (_BATCH_SIZE, 1008)
121
+ assert pred_nb.shape == (_BATCH_SIZE, 1008)
122
+
123
+ # The reference model requires input images to have range [0, 255].
124
+ img = raw_img.to(torch.float32).cuda()
125
+ ref_feat = ref_model(img, return_features=True)
126
+ ref_pred = ref_model(img)
127
+ ref_pred_nb = ref_model(img, no_output_bias=True)
128
+ assert ref_feat.shape == (_BATCH_SIZE, 2048)
129
+ assert ref_pred.shape == (_BATCH_SIZE, 1008)
130
+ assert ref_pred_nb.shape == (_BATCH_SIZE, 1008)
131
+
132
+ print(f' Size {size}x{size}, feature:\n '
133
+ f'mean: {(feat - ref_feat).abs().mean().item():.3e}, '
134
+ f'max: {(feat - ref_feat).abs().max().item():.3e}, '
135
+ f'ref_mean: {ref_feat.abs().mean().item():.3e}, '
136
+ f'ref_max: {ref_feat.abs().max().item():.3e}.')
137
+ print(f' Size {size}x{size}, prediction:\n '
138
+ f'mean: {(pred - ref_pred).abs().mean().item():.3e}, '
139
+ f'max: {(pred - ref_pred).abs().max().item():.3e}, '
140
+ f'ref_mean: {ref_pred.abs().mean().item():.3e}, '
141
+ f'ref_max: {ref_pred.abs().max().item():.3e}.')
142
+ print(f' Size {size}x{size}, prediction (without bias):\n '
143
+ f'mean: {(pred_nb - ref_pred_nb).abs().mean().item():.3e}, '
144
+ f'max: {(pred_nb - ref_pred_nb).abs().max().item():.3e}, '
145
+ f'ref_mean: {ref_pred_nb.abs().mean().item():.3e}, '
146
+ f'ref_max: {ref_pred_nb.abs().max().item():.3e}.')
models/utils/__init__.py ADDED
File without changes
models/utils/ops.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains operators for neural networks."""
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+ __all__ = ['all_gather']
8
+
9
+
10
+ def all_gather(tensor):
11
+ """Gathers tensor from all devices and executes averaging."""
12
+ if not dist.is_initialized():
13
+ return tensor
14
+
15
+ world_size = dist.get_world_size()
16
+ tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
17
+ dist.all_gather(tensor_list, tensor, async_op=False)
18
+ return torch.stack(tensor_list, dim=0).mean(dim=0)
requirements/convert.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.8.1
2
+ tensorflow-gpu==1.15
3
+ ninja==1.10.2
4
+ scikit-video==1.1.11
5
+ pillow==9.0.0
6
+ opencv-python-headless==4.5.5.62
7
+ requests
8
+ bs4
9
+ tqdm
10
+ rich
11
+ easydict
requirements/develop.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ bpytop # Monitor system resources.
2
+ gpustat # Monitor GPU usage.
3
+ pylint # Check coding style.
requirements/minimal.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.8.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
2
+ torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
3
+ tensorboard==2.7.0
4
+ torch-tb-profiler==0.3.1
5
+ ninja==1.10.2
6
+ numpy==1.21.5
7
+ scipy==1.7.3
8
+ scikit-learn==1.0.2
9
+ scikit-video==1.1.11
10
+ pillow==9.0.0
11
+ opencv-python-headless==4.5.5.62
12
+ requests
13
+ bs4
14
+ tqdm
15
+ rich
16
+ click
17
+ cloup
18
+ psutil
19
+ easydict
20
+ lmdb
21
+ matplotlib
synthesis.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Script that synthesizes images with pre-trained models.
3
+
4
+ Support StyleGAN2 and StyleGAN3.
5
+ """
6
+
7
+ import os
8
+ import argparse
9
+ from tqdm import tqdm
10
+ import numpy as np
11
+
12
+ import torch
13
+ from models import build_model
14
+ from utils.visualizers.html_visualizer import HtmlVisualizer
15
+ from utils.image_utils import save_image, resize_image
16
+ from utils.image_utils import postprocess_image
17
+ from utils.custom_utils import to_numpy
18
+
19
+
20
+ def parse_args():
21
+ """Parses arguments."""
22
+ parser = argparse.ArgumentParser()
23
+ group = parser.add_argument_group('General options.')
24
+ group.add_argument('weight_path', type=str,
25
+ help='Weight path to the pre-trained model.')
26
+ group.add_argument('--save_dir', type=str, default=None,
27
+ help='Directory to save the results. If not specified, '
28
+ 'the results will be saved to '
29
+ '`work_dirs/{TASK_SPECIFIC}/` by default.')
30
+ group.add_argument('--job', type=str, default='synthesize',
31
+ help='Name for the job. (default: synthesize)')
32
+ group.add_argument('--seed', type=int, default=4,
33
+ help='Seed for sampling. (default: 4)')
34
+ group.add_argument('--nums', type=int, default=100,
35
+ help='Number of samples to synthesized. (default: 100)')
36
+ group.add_argument('--img_size', type=int, default=1024,
37
+ help='Size of the synthesized images. (default: 1024)')
38
+ group.add_argument('--vis_size', type=int, default=256,
39
+ help='Size of the visualize images. (default: 256)')
40
+ group.add_argument('--w_dim', type=int, default=512,
41
+ help='Dimension of the latent w. (default: 512)')
42
+ group.add_argument('--batch_size', type=int, default=4,
43
+ help='Batch size. (default: 4)')
44
+ group.add_argument('--save_jpg', action='store_true', default=False,
45
+ help='Whether to save raw image. (default: False)')
46
+ group.add_argument('-d', '--data_name', type=str, default='ffhq',
47
+ help='Name of the datasets. (default: ffhq)')
48
+ group.add_argument('--latent_path', type=str, default='',
49
+ help='Path to the given latent codes. (default: None)')
50
+ group.add_argument('--trunc_psi', type=float, default=0.7,
51
+ help='Psi factor used for truncation. (default: 0.7)')
52
+ group.add_argument('--trunc_layers', type=int, default=8,
53
+ help='Number of layers to perform truncation.'
54
+ ' (default: 8)')
55
+
56
+ group = parser.add_argument_group('StyleGAN2')
57
+ group.add_argument('--stylegan2', action='store_true',
58
+ help='Whether or not using StyleGAN2. (default: False)')
59
+ group.add_argument('--scale_stylegan2', type=float, default=1.0,
60
+ help='Scale for the number of channel fro stylegan2.')
61
+ group.add_argument('--randomize_noise', type=str, default='const',
62
+ help='Noise type when synthesizing. (const or random)')
63
+
64
+ group = parser.add_argument_group('StyleGAN3')
65
+ group.add_argument('--stylegan3', action='store_true',
66
+ help='Whether or not using StyleGAN3. (default: False)')
67
+ group.add_argument('--cfg', type=str, default='T',
68
+ help='Config of the stylegan3 (T/R).')
69
+ group.add_argument('--scale_stylegan3r', type=float, default=2.0,
70
+ help='Scale for the number of channel for stylegan3 R.')
71
+ group.add_argument('--scale_stylegan3t', type=float, default=1.0,
72
+ help='Scale for the number of channel for stylegan3 T.')
73
+ group.add_argument('--tx', type=float, default=0,
74
+ help='Translate X-coordinate. (default: 0.0)')
75
+ group.add_argument('--ty', type=float, default=0,
76
+ help='Translate Y-coordinate. (default: 0.0)')
77
+ group.add_argument('--rotate', type=float, default=0,
78
+ help='Rotation angle in degrees. (default: 0)')
79
+ return parser.parse_args()
80
+
81
+
82
+ def main():
83
+ """Main function."""
84
+ args = parse_args()
85
+ # Parse model configuration.
86
+ assert (args.stylegan2 and not args.stylegan3) or \
87
+ (not args.stylegan2 and args.stylegan3)
88
+ job_disc = ''
89
+ if args.stylegan2:
90
+ config = dict(model_type='StyleGAN2Generator',
91
+ resolution=args.img_size,
92
+ w_dim=args.w_dim,
93
+ fmaps_base=int(args.scale_stylegan2 * (32 << 10)),
94
+ fmaps_max=512,)
95
+ job_disc += 'stylegan2'
96
+ else:
97
+ if args.stylegan3 and args.cfg == 'R':
98
+ config = dict(model_type='StyleGAN3Generator',
99
+ resolution=args.img_size,
100
+ w_dim=args.w_dim,
101
+ fmaps_base=int(args.scale_stylegan3r * (32 << 10)),
102
+ fmaps_max=1024,
103
+ use_radial_filter=True,)
104
+ job_disc += 'stylegan3r'
105
+ elif args.stylegan3 and args.cfg == 'T':
106
+ config = dict(model_type='StyleGAN3Generator',
107
+ resolution=args.img_size,
108
+ w_dim=args.w_dim,
109
+ fmaps_base=int(args.scale_stylegan3t * (32 << 10)),
110
+ fmaps_max=512,
111
+ use_radial_filter=False,
112
+ kernel_size=3,)
113
+ job_disc += 'stylegan3t'
114
+ else:
115
+ raise TypeError(f'StyleGAN3 config type error, need `R/T`,'
116
+ f' but got {args.cfg} instead.')
117
+
118
+ # Get work directory and job name.
119
+ save_dir = args.save_dir or f'work_dirs/{args.job}/{args.data_name}'
120
+ os.makedirs(save_dir, exist_ok=True)
121
+ job_name = f'seed_{args.seed}_num_{args.nums}_{job_disc}'
122
+ os.makedirs(f'{save_dir}/{job_name}', exist_ok=True)
123
+
124
+ # Build generation and get synthesis kwargs.
125
+ print('Building generator...')
126
+ generator = build_model(**config)
127
+ synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
128
+ trunc_layers=args.trunc_layers,)
129
+ # Load pre-trained weights.
130
+ checkpoint_path = args.weight_path
131
+ print(f'Loading checkpoint from `{checkpoint_path}` ...')
132
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
133
+ if 'generator_smooth' in checkpoint:
134
+ generator.load_state_dict(checkpoint['generator_smooth'])
135
+ else:
136
+ generator.load_state_dict(checkpoint['generator'])
137
+ generator = generator.eval().cuda()
138
+ print('Finish loading checkpoint.')
139
+
140
+ np.random.seed(args.seed)
141
+ torch.manual_seed(args.seed)
142
+ if os.path.exists(args.latent_path):
143
+ latent_zs = np.load(args.latent_path)
144
+ latent_zs = latent_zs[:args.nums]
145
+ else:
146
+ latent_zs = np.random.randn(args.nums, generator.z_dim)
147
+ num_images = latent_zs.shape[0]
148
+ latent_zs = torch.from_numpy(latent_zs.astype(np.float32))
149
+ html = HtmlVisualizer(grid_size=num_images)
150
+ print(f'Synthesizing {num_images} images ...')
151
+ latent_ws = []
152
+ for batch_idx in tqdm(range(0, num_images, args.batch_size)):
153
+ latent_z = latent_zs[batch_idx:batch_idx + args.batch_size]
154
+ latent_z = latent_z.cuda()
155
+ with torch.no_grad():
156
+ g_outputs = generator(latent_z, **synthesis_kwargs)
157
+ g_image = to_numpy(g_outputs['image'])
158
+ images = postprocess_image(g_image)
159
+ for idx in range(images.shape[0]):
160
+ sub_idx = batch_idx + idx
161
+ img = images[idx]
162
+ row_idx, col_idx = divmod(sub_idx, html.num_cols)
163
+ image = resize_image(img, (args.vis_size, args.vis_size))
164
+ html.set_cell(row_idx, col_idx, image=image,
165
+ text=f'Sample {sub_idx:06d}')
166
+ if args.save_jpg:
167
+ save_path = f'{save_dir}/{job_name}/{sub_idx:06d}.jpg'
168
+ save_image(save_path, img)
169
+ latent_ws.append(to_numpy(g_outputs['wp']))
170
+ latent_ws = np.concatenate(latent_ws, axis=0)
171
+ print(f'shape of the latent code: {latent_ws.shape}')
172
+ np.save(f'{save_dir}/{job_name}/latent_codes.npy', latent_ws)
173
+ html.save(f'{save_dir}/{job_name}.html')
174
+ print(f'Finish synthesizing {num_images} samples.')
175
+
176
+
177
+ if __name__ == '__main__':
178
+ main()
third_party/__init__.py ADDED
File without changes
third_party/stylegan2_official_ops/README.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Operators for StyleGAN2
2
+
3
+ All files in this directory are borrowed from repository [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including
4
+
5
+ - `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator.
6
+ - `upfirdn2d.setup_filter()`: Set up the kernel used for filtering.
7
+ - `upfirdn2d.filter2d()`: Filtering a 2D feature map with given kernel.
8
+ - `upfirdn2d.upsample2d()`: Upsampling a 2D feature map.
9
+ - `upfirdn2d.downsample2d()`: Downsampling a 2D feature map.
10
+ - `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map.
11
+ - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
12
+ - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
13
+ - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`)
14
+
15
+ We make following slight modifications beyond disabling some lint warnings:
16
+
17
+ - Line 25 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
18
+ - Line 35 of file `custom_ops.py`: Disable log message when setting up customized operators.
19
+ - Line 53/89 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your cuda binary does not locate at `/usr/local/cuda/bin`, please specify in function `_find_compiler_bindir_posix()`.*)
20
+ - Line 24 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
21
+ - Line 32 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default.
22
+ - Line 36 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator.
23
+ - Line 33 of file `conv2d_gradfix.py`: Enable customized convolution operators by default.
24
+ - Line 46/51 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators.
25
+ - Line 36/66 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators.
26
+ - Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator.
27
+
28
+ Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default.
third_party/stylegan2_official_ops/__init__.py ADDED
File without changes
third_party/stylegan2_official_ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
third_party/stylegan2_official_ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
+ template <class T, int A>
24
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
25
+ {
26
+ typedef typename InternalType<T>::scalar_t scalar_t;
27
+ int G = p.grad;
28
+ scalar_t alpha = (scalar_t)p.alpha;
29
+ scalar_t gain = (scalar_t)p.gain;
30
+ scalar_t clamp = (scalar_t)p.clamp;
31
+ scalar_t one = (scalar_t)1;
32
+ scalar_t two = (scalar_t)2;
33
+ scalar_t expRange = (scalar_t)80;
34
+ scalar_t halfExpRange = (scalar_t)40;
35
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
+
38
+ // Loop over elements.
39
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
+ {
42
+ // Load.
43
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
49
+ scalar_t y = 0;
50
+
51
+ // Apply bias.
52
+ ((G == 0) ? x : xref) += b;
53
+
54
+ // linear
55
+ if (A == 1)
56
+ {
57
+ if (G == 0) y = x;
58
+ if (G == 1) y = x;
59
+ }
60
+
61
+ // relu
62
+ if (A == 2)
63
+ {
64
+ if (G == 0) y = (x > 0) ? x : 0;
65
+ if (G == 1) y = (yy > 0) ? x : 0;
66
+ }
67
+
68
+ // lrelu
69
+ if (A == 3)
70
+ {
71
+ if (G == 0) y = (x > 0) ? x : x * alpha;
72
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
73
+ }
74
+
75
+ // tanh
76
+ if (A == 4)
77
+ {
78
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
+ if (G == 1) y = x * (one - yy * yy);
80
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
+ }
82
+
83
+ // sigmoid
84
+ if (A == 5)
85
+ {
86
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
+ if (G == 1) y = x * yy * (one - yy);
88
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
+ }
90
+
91
+ // elu
92
+ if (A == 6)
93
+ {
94
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
+ }
98
+
99
+ // selu
100
+ if (A == 7)
101
+ {
102
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
+ }
106
+
107
+ // softplus
108
+ if (A == 8)
109
+ {
110
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
+ if (G == 1) y = x * (one - exp(-yy));
112
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
+ }
114
+
115
+ // swish
116
+ if (A == 9)
117
+ {
118
+ if (G == 0)
119
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
+ else
121
+ {
122
+ scalar_t c = exp(xref);
123
+ scalar_t d = c + one;
124
+ if (G == 1)
125
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
+ else
127
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
+ }
130
+ }
131
+
132
+ // Apply gain.
133
+ y *= gain * dy;
134
+
135
+ // Clamp.
136
+ if (clamp >= 0)
137
+ {
138
+ if (G == 0)
139
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
+ else
141
+ y = (yref > -clamp & yref < clamp) ? y : 0;
142
+ }
143
+
144
+ // Store.
145
+ ((T*)p.y)[xi] = (T)y;
146
+ }
147
+ }
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
+ {
154
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
+ return NULL;
164
+ }
165
+
166
+ //------------------------------------------------------------------------
167
+ // Template specializations.
168
+
169
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
third_party/stylegan2_official_ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
+ struct bias_act_kernel_params
13
+ {
14
+ const void* x; // [sizeX]
15
+ const void* b; // [sizeB] or NULL
16
+ const void* xref; // [sizeX] or NULL
17
+ const void* yref; // [sizeX] or NULL
18
+ const void* dy; // [sizeX] or NULL
19
+ void* y; // [sizeX]
20
+
21
+ int grad;
22
+ int act;
23
+ float alpha;
24
+ float gain;
25
+ float clamp;
26
+
27
+ int sizeX;
28
+ int sizeB;
29
+ int stepB;
30
+ int loopX;
31
+ };
32
+
33
+ //------------------------------------------------------------------------
34
+ // CUDA kernel selection.
35
+
36
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
+
38
+ //------------------------------------------------------------------------
third_party/stylegan2_official_ops/bias_act.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Custom ops to fuse bias and activation as one operator, which is efficient.
12
+
13
+ Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
14
+ """
15
+
16
+ # pylint: disable=line-too-long
17
+ # pylint: disable=missing-class-docstring
18
+ # pylint: disable=global-statement
19
+ # pylint: disable=bare-except
20
+
21
+ import os
22
+ import warnings
23
+ import traceback
24
+ from easydict import EasyDict
25
+ import numpy as np
26
+ import torch
27
+
28
+ from . import custom_ops
29
+ from . import misc
30
+
31
+ #----------------------------------------------------------------------------
32
+
33
+ activation_funcs = {
34
+ 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
35
+ 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
36
+ 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
37
+ 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
38
+ 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
39
+ 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
40
+ 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
41
+ 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
42
+ 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
43
+ }
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ _inited = False
48
+ _plugin = None
49
+ _null_tensor = torch.empty([0])
50
+
51
+ def _init():
52
+ global _inited, _plugin
53
+ if not _inited:
54
+ _inited = True
55
+ sources = ['bias_act.cpp', 'bias_act.cu']
56
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
57
+ try:
58
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
59
+ except:
60
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
61
+ return _plugin is not None
62
+
63
+ #----------------------------------------------------------------------------
64
+
65
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
66
+ r"""Fused bias and activation function.
67
+
68
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
69
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
70
+ the fused op is considerably more efficient than performing the same calculation
71
+ using standard PyTorch ops. It supports first and second order gradients,
72
+ but not third order gradients.
73
+
74
+ Args:
75
+ x: Input activation tensor. Can be of any shape.
76
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
77
+ as `x`. The shape must be known, and it must match the dimension of `x`
78
+ corresponding to `dim`.
79
+ dim: The dimension in `x` corresponding to the elements of `b`.
80
+ The value of `dim` is ignored if `b` is not specified.
81
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
82
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
83
+ See `activation_funcs` for a full list. `None` is not allowed.
84
+ alpha: Shape parameter for the activation function, or `None` to use the default.
85
+ gain: Scaling factor for the output tensor, or `None` to use default.
86
+ See `activation_funcs` for the default scaling of each activation function.
87
+ If unsure, consider specifying 1.
88
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
89
+ the clamping (default).
90
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
91
+
92
+ Returns:
93
+ Tensor of the same shape and datatype as `x`.
94
+ """
95
+ assert isinstance(x, torch.Tensor)
96
+ assert impl in ['ref', 'cuda']
97
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
98
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
99
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
100
+
101
+ #----------------------------------------------------------------------------
102
+
103
+ @misc.profiled_function
104
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
105
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
106
+ """
107
+ assert isinstance(x, torch.Tensor)
108
+ assert clamp is None or clamp >= 0
109
+ spec = activation_funcs[act]
110
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
111
+ gain = float(gain if gain is not None else spec.def_gain)
112
+ clamp = float(clamp if clamp is not None else -1)
113
+
114
+ # Add bias.
115
+ if b is not None:
116
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
117
+ assert 0 <= dim < x.ndim
118
+ assert b.shape[0] == x.shape[dim]
119
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
120
+
121
+ # Evaluate activation function.
122
+ alpha = float(alpha)
123
+ x = spec.func(x, alpha=alpha)
124
+
125
+ # Scale by gain.
126
+ gain = float(gain)
127
+ if gain != 1:
128
+ x = x * gain
129
+
130
+ # Clamp.
131
+ if clamp >= 0:
132
+ x = x.clamp(-clamp, clamp)
133
+ return x
134
+
135
+ #----------------------------------------------------------------------------
136
+
137
+ _bias_act_cuda_cache = dict()
138
+
139
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
140
+ """Fast CUDA implementation of `bias_act()` using custom ops.
141
+ """
142
+ # Parse arguments.
143
+ assert clamp is None or clamp >= 0
144
+ spec = activation_funcs[act]
145
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
146
+ gain = float(gain if gain is not None else spec.def_gain)
147
+ clamp = float(clamp if clamp is not None else -1)
148
+
149
+ # Lookup from cache.
150
+ key = (dim, act, alpha, gain, clamp)
151
+ if key in _bias_act_cuda_cache:
152
+ return _bias_act_cuda_cache[key]
153
+
154
+ # Forward op.
155
+ class BiasActCuda(torch.autograd.Function):
156
+ @staticmethod
157
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
158
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
159
+ x = x.contiguous(memory_format=ctx.memory_format)
160
+ b = b.contiguous() if b is not None else _null_tensor
161
+ y = x
162
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
163
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
164
+ ctx.save_for_backward(
165
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
166
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
167
+ y if 'y' in spec.ref else _null_tensor)
168
+ return y
169
+
170
+ @staticmethod
171
+ def backward(ctx, dy): # pylint: disable=arguments-differ
172
+ dy = dy.contiguous(memory_format=ctx.memory_format)
173
+ x, b, y = ctx.saved_tensors
174
+ dx = None
175
+ db = None
176
+
177
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
178
+ dx = dy
179
+ if act != 'linear' or gain != 1 or clamp >= 0:
180
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
181
+
182
+ if ctx.needs_input_grad[1]:
183
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
184
+
185
+ return dx, db
186
+
187
+ # Backward op.
188
+ class BiasActCudaGrad(torch.autograd.Function):
189
+ @staticmethod
190
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
191
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
192
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
193
+ ctx.save_for_backward(
194
+ dy if spec.has_2nd_grad else _null_tensor,
195
+ x, b, y)
196
+ return dx
197
+
198
+ @staticmethod
199
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
200
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
201
+ dy, x, b, y = ctx.saved_tensors
202
+ d_dy = None
203
+ d_x = None
204
+ d_b = None
205
+ d_y = None
206
+
207
+ if ctx.needs_input_grad[0]:
208
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
209
+
210
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
211
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
212
+
213
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
214
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
215
+
216
+ return d_dy, d_x, d_b, d_y
217
+
218
+ # Add to cache.
219
+ _bias_act_cuda_cache[key] = BiasActCuda
220
+ return BiasActCuda
221
+
222
+ #----------------------------------------------------------------------------
223
+
224
+ # pylint: enable=line-too-long
225
+ # pylint: enable=missing-class-docstring
226
+ # pylint: enable=global-statement
227
+ # pylint: enable=bare-except
third_party/stylegan2_official_ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Custom replacement for convolution operators.
12
+
13
+ Operators in this file support arbitrarily high order gradients with zero
14
+ performance penalty. Please set `impl` as `cuda` to use faster customized
15
+ operators, OR as `ref` to use native `torch.nn.functional.conv2d` and
16
+ `torch.nn.functional.conv_transpose2d`.
17
+
18
+ Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
19
+ """
20
+
21
+ # pylint: disable=redefined-builtin
22
+ # pylint: disable=arguments-differ
23
+ # pylint: disable=protected-access
24
+ # pylint: disable=line-too-long
25
+ # pylint: disable=global-statement
26
+ # pylint: disable=missing-class-docstring
27
+ # pylint: disable=missing-function-docstring
28
+
29
+ import warnings
30
+ import contextlib
31
+ import torch
32
+
33
+ enabled = True # Enable the custom op by setting this to true.
34
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
35
+
36
+ @contextlib.contextmanager
37
+ def no_weight_gradients():
38
+ global weight_gradients_disabled
39
+ old = weight_gradients_disabled
40
+ weight_gradients_disabled = True
41
+ yield
42
+ weight_gradients_disabled = old
43
+
44
+ #----------------------------------------------------------------------------
45
+
46
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, impl='cuda'):
47
+ if impl == 'cuda' and _should_use_custom_op(input):
48
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
49
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
50
+
51
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, impl='cuda'):
52
+ if impl == 'cuda' and _should_use_custom_op(input):
53
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
54
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
55
+
56
+ #----------------------------------------------------------------------------
57
+
58
+ def _should_use_custom_op(input):
59
+ assert isinstance(input, torch.Tensor)
60
+ if (not enabled) or (not torch.backends.cudnn.enabled):
61
+ return False
62
+ if input.device.type != 'cuda':
63
+ return False
64
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
65
+ return True
66
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
67
+ return False
68
+
69
+ def _tuple_of_ints(xs, ndim):
70
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
71
+ assert len(xs) == ndim
72
+ assert all(isinstance(x, int) for x in xs)
73
+ return xs
74
+
75
+ #----------------------------------------------------------------------------
76
+
77
+ _conv2d_gradfix_cache = dict()
78
+
79
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
80
+ # Parse arguments.
81
+ ndim = 2
82
+ weight_shape = tuple(weight_shape)
83
+ stride = _tuple_of_ints(stride, ndim)
84
+ padding = _tuple_of_ints(padding, ndim)
85
+ output_padding = _tuple_of_ints(output_padding, ndim)
86
+ dilation = _tuple_of_ints(dilation, ndim)
87
+
88
+ # Lookup from cache.
89
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
90
+ if key in _conv2d_gradfix_cache:
91
+ return _conv2d_gradfix_cache[key]
92
+
93
+ # Validate arguments.
94
+ assert groups >= 1
95
+ assert len(weight_shape) == ndim + 2
96
+ assert all(stride[i] >= 1 for i in range(ndim))
97
+ assert all(padding[i] >= 0 for i in range(ndim))
98
+ assert all(dilation[i] >= 0 for i in range(ndim))
99
+ if not transpose:
100
+ assert all(output_padding[i] == 0 for i in range(ndim))
101
+ else: # transpose
102
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
103
+
104
+ # Helpers.
105
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
106
+ def calc_output_padding(input_shape, output_shape):
107
+ if transpose:
108
+ return [0, 0]
109
+ return [
110
+ input_shape[i + 2]
111
+ - (output_shape[i + 2] - 1) * stride[i]
112
+ - (1 - 2 * padding[i])
113
+ - dilation[i] * (weight_shape[i + 2] - 1)
114
+ for i in range(ndim)
115
+ ]
116
+
117
+ # Forward & backward.
118
+ class Conv2d(torch.autograd.Function):
119
+ @staticmethod
120
+ def forward(ctx, input, weight, bias):
121
+ assert weight.shape == weight_shape
122
+ if not transpose:
123
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
124
+ else: # transpose
125
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
126
+ ctx.save_for_backward(input, weight)
127
+ return output
128
+
129
+ @staticmethod
130
+ def backward(ctx, grad_output):
131
+ input, weight = ctx.saved_tensors
132
+ grad_input = None
133
+ grad_weight = None
134
+ grad_bias = None
135
+
136
+ if ctx.needs_input_grad[0]:
137
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
138
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
139
+ assert grad_input.shape == input.shape
140
+
141
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
142
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
143
+ assert grad_weight.shape == weight_shape
144
+
145
+ if ctx.needs_input_grad[2]:
146
+ grad_bias = grad_output.sum([0, 2, 3])
147
+
148
+ return grad_input, grad_weight, grad_bias
149
+
150
+ # Gradient with respect to the weights.
151
+ class Conv2dGradWeight(torch.autograd.Function):
152
+ @staticmethod
153
+ def forward(ctx, grad_output, input):
154
+ op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
155
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
156
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
157
+ assert grad_weight.shape == weight_shape
158
+ ctx.save_for_backward(grad_output, input)
159
+ return grad_weight
160
+
161
+ @staticmethod
162
+ def backward(ctx, grad2_grad_weight):
163
+ grad_output, input = ctx.saved_tensors
164
+ grad2_grad_output = None
165
+ grad2_input = None
166
+
167
+ if ctx.needs_input_grad[0]:
168
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
169
+ assert grad2_grad_output.shape == grad_output.shape
170
+
171
+ if ctx.needs_input_grad[1]:
172
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
173
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
174
+ assert grad2_input.shape == input.shape
175
+
176
+ return grad2_grad_output, grad2_input
177
+
178
+ _conv2d_gradfix_cache[key] = Conv2d
179
+ return Conv2d
180
+
181
+ #----------------------------------------------------------------------------
182
+
183
+ # pylint: enable=redefined-builtin
184
+ # pylint: enable=arguments-differ
185
+ # pylint: enable=protected-access
186
+ # pylint: enable=line-too-long
187
+ # pylint: enable=global-statement
188
+ # pylint: enable=missing-class-docstring
189
+ # pylint: enable=missing-function-docstring
third_party/stylegan2_official_ops/conv2d_resample.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """2D convolution with optional up/downsampling.
12
+
13
+ Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
14
+ """
15
+
16
+ # pylint: disable=line-too-long
17
+
18
+ import torch
19
+
20
+ from . import misc
21
+ from . import conv2d_gradfix
22
+ from . import upfirdn2d
23
+ from .upfirdn2d import _parse_padding
24
+ from .upfirdn2d import _get_filter_size
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ def _get_weight_shape(w):
29
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
30
+ shape = [int(sz) for sz in w.shape]
31
+ misc.assert_shape(w, shape)
32
+ return shape
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'):
37
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
38
+ """
39
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
40
+
41
+ # Flip weight if requested.
42
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
43
+ w = w.flip([2, 3])
44
+
45
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
46
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
47
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
48
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
49
+ if out_channels <= 4 and groups == 1:
50
+ in_shape = x.shape
51
+ x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
52
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
53
+ else:
54
+ x = x.to(memory_format=torch.contiguous_format)
55
+ w = w.to(memory_format=torch.contiguous_format)
56
+ x = conv2d_gradfix.conv2d(x, w, groups=groups, impl=impl)
57
+ return x.to(memory_format=torch.channels_last)
58
+
59
+ # Otherwise => execute using conv2d_gradfix.
60
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
61
+ return op(x, w, stride=stride, padding=padding, groups=groups, impl=impl)
62
+
63
+ #----------------------------------------------------------------------------
64
+
65
+ @misc.profiled_function
66
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'):
67
+ r"""2D convolution with optional up/downsampling.
68
+
69
+ Padding is performed only once at the beginning, not between the operations.
70
+
71
+ Args:
72
+ x: Input tensor of shape
73
+ `[batch_size, in_channels, in_height, in_width]`.
74
+ w: Weight tensor of shape
75
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
76
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
77
+ calling upfirdn2d.setup_filter(). None = identity (default).
78
+ up: Integer upsampling factor (default: 1).
79
+ down: Integer downsampling factor (default: 1).
80
+ padding: Padding with respect to the upsampled image. Can be a single number
81
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
82
+ (default: 0).
83
+ groups: Split input channels into N groups (default: 1).
84
+ flip_weight: False = convolution, True = correlation (default: True).
85
+ flip_filter: False = convolution, True = correlation (default: False).
86
+ impl: Implementation mode of customized ops. 'ref' for native PyTorch
87
+ implementation, 'cuda' for `.cu` implementation
88
+ (default: 'cuda').
89
+
90
+ Returns:
91
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
92
+ """
93
+ # Validate arguments.
94
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
95
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
96
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
97
+ assert isinstance(up, int) and (up >= 1)
98
+ assert isinstance(down, int) and (down >= 1)
99
+ assert isinstance(groups, int) and (groups >= 1)
100
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
101
+ fw, fh = _get_filter_size(f)
102
+ px0, px1, py0, py1 = _parse_padding(padding)
103
+
104
+ # Adjust padding to account for up/downsampling.
105
+ if up > 1:
106
+ px0 += (fw + up - 1) // 2
107
+ px1 += (fw - up) // 2
108
+ py0 += (fh + up - 1) // 2
109
+ py1 += (fh - up) // 2
110
+ if down > 1:
111
+ px0 += (fw - down + 1) // 2
112
+ px1 += (fw - down) // 2
113
+ py0 += (fh - down + 1) // 2
114
+ py1 += (fh - down) // 2
115
+
116
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
117
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
118
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
119
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
120
+ return x
121
+
122
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
123
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
124
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
125
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
126
+ return x
127
+
128
+ # Fast path: downsampling only => use strided convolution.
129
+ if down > 1 and up == 1:
130
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
131
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl)
132
+ return x
133
+
134
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
135
+ if up > 1:
136
+ if groups == 1:
137
+ w = w.transpose(0, 1)
138
+ else:
139
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
140
+ w = w.transpose(1, 2)
141
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
142
+ px0 -= kw - 1
143
+ px1 -= kw - up
144
+ py0 -= kh - 1
145
+ py1 -= kh - up
146
+ pxt = max(min(-px0, -px1), 0)
147
+ pyt = max(min(-py0, -py1), 0)
148
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl)
149
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl)
150
+ if down > 1:
151
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
152
+ return x
153
+
154
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
155
+ if up == 1 and down == 1:
156
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
157
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl)
158
+
159
+ # Fallback: Generic reference implementation.
160
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
161
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
162
+ if down > 1:
163
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
164
+ return x
165
+
166
+ #----------------------------------------------------------------------------
167
+
168
+ # pylint: enable=line-too-long
third_party/stylegan2_official_ops/custom_ops.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Utility functions to setup customized operators.
12
+
13
+ Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
14
+ """
15
+
16
+ # pylint: disable=line-too-long
17
+ # pylint: disable=missing-function-docstring
18
+ # pylint: disable=useless-suppression
19
+ # pylint: disable=inconsistent-quotes
20
+
21
+ import os
22
+ import glob
23
+ import importlib
24
+ import hashlib
25
+ import shutil
26
+ from pathlib import Path
27
+
28
+ import torch
29
+ from torch.utils.file_baton import FileBaton
30
+ import torch.utils.cpp_extension
31
+
32
+ #----------------------------------------------------------------------------
33
+ # Global options.
34
+
35
+ verbosity = 'none' # Verbosity level: 'none', 'brief', 'full'
36
+
37
+ #----------------------------------------------------------------------------
38
+ # Internal helper funcs.
39
+
40
+ def _find_compiler_bindir():
41
+ patterns = [
42
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
43
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
44
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
45
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
46
+ ]
47
+ for pattern in patterns:
48
+ matches = sorted(glob.glob(pattern))
49
+ if len(matches):
50
+ return matches[-1]
51
+ return None
52
+
53
+ def _find_compiler_bindir_posix():
54
+ patterns = [
55
+ '/usr/local/cuda/bin'
56
+ ]
57
+ for pattern in patterns:
58
+ matches = sorted(glob.glob(pattern))
59
+ if len(matches):
60
+ return matches[-1]
61
+ return None
62
+
63
+ #----------------------------------------------------------------------------
64
+ # Main entry point for compiling and loading C++/CUDA plugins.
65
+
66
+ _cached_plugins = dict()
67
+
68
+ def get_plugin(module_name, sources, **build_kwargs):
69
+ assert verbosity in ['none', 'brief', 'full']
70
+
71
+ # Already cached?
72
+ if module_name in _cached_plugins:
73
+ return _cached_plugins[module_name]
74
+
75
+ # Print status.
76
+ if verbosity == 'full':
77
+ print(f'Setting up PyTorch plugin "{module_name}"...')
78
+ elif verbosity == 'brief':
79
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
80
+
81
+ try: # pylint: disable=too-many-nested-blocks
82
+ # Make sure we can find the necessary compiler binaries.
83
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
84
+ compiler_bindir = _find_compiler_bindir()
85
+ if compiler_bindir is None:
86
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
87
+ os.environ['PATH'] += ';' + compiler_bindir
88
+
89
+ elif os.name == 'posix':
90
+ compiler_bindir = _find_compiler_bindir_posix()
91
+ if compiler_bindir is None:
92
+ raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".')
93
+ os.environ['PATH'] += ';' + compiler_bindir
94
+
95
+ # Compile and load.
96
+ verbose_build = (verbosity == 'full')
97
+
98
+ # Incremental build md5sum trickery. Copies all the input source files
99
+ # into a cached build directory under a combined md5 digest of the input
100
+ # source files. Copying is done only if the combined digest has changed.
101
+ # This keeps input file timestamps and filenames the same as in previous
102
+ # extension builds, allowing for fast incremental rebuilds.
103
+ #
104
+ # This optimization is done only in case all the source files reside in
105
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
106
+ # environment variable is set (we take this as a signal that the user
107
+ # actually cares about this.)
108
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
109
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
110
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
111
+
112
+ # Compute a combined hash digest for all source files in the same
113
+ # custom op directory (usually .cu, .cpp, .py and .h files).
114
+ hash_md5 = hashlib.md5()
115
+ for src in all_source_files:
116
+ with open(src, 'rb') as f:
117
+ hash_md5.update(f.read())
118
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
119
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
120
+
121
+ if not os.path.isdir(digest_build_dir):
122
+ os.makedirs(digest_build_dir, exist_ok=True)
123
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
124
+ if baton.try_acquire():
125
+ try:
126
+ for src in all_source_files:
127
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
128
+ finally:
129
+ baton.release()
130
+ else:
131
+ # Someone else is copying source files under the digest dir,
132
+ # wait until done and continue.
133
+ baton.wait()
134
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
135
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
136
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
137
+ else:
138
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
139
+ module = importlib.import_module(module_name)
140
+
141
+ except:
142
+ if verbosity == 'brief':
143
+ print('Failed!')
144
+ raise
145
+
146
+ # Print status and add to cache.
147
+ if verbosity == 'full':
148
+ print(f'Done setting up PyTorch plugin "{module_name}".')
149
+ elif verbosity == 'brief':
150
+ print('Done.')
151
+ _cached_plugins[module_name] = module
152
+ return module
153
+
154
+ #----------------------------------------------------------------------------
155
+
156
+ # pylint: enable=line-too-long
157
+ # pylint: enable=missing-function-docstring
158
+ # pylint: enable=useless-suppression
159
+ # pylint: enable=inconsistent-quotes
third_party/stylegan2_official_ops/fma.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.
12
+
13
+ Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
14
+ """
15
+
16
+ # pylint: disable=line-too-long
17
+ # pylint: disable=missing-function-docstring
18
+
19
+ import torch
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ def fma(a, b, c, impl='cuda'): # => a * b + c
24
+ if impl == 'cuda':
25
+ return _FusedMultiplyAdd.apply(a, b, c)
26
+ return torch.addcmul(c, a, b)
27
+
28
+ #----------------------------------------------------------------------------
29
+
30
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
31
+ @staticmethod
32
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
33
+ out = torch.addcmul(c, a, b)
34
+ ctx.save_for_backward(a, b)
35
+ ctx.c_shape = c.shape
36
+ return out
37
+
38
+ @staticmethod
39
+ def backward(ctx, dout): # pylint: disable=arguments-differ
40
+ a, b = ctx.saved_tensors
41
+ c_shape = ctx.c_shape
42
+ da = None
43
+ db = None
44
+ dc = None
45
+
46
+ if ctx.needs_input_grad[0]:
47
+ da = _unbroadcast(dout * b, a.shape)
48
+
49
+ if ctx.needs_input_grad[1]:
50
+ db = _unbroadcast(dout * a, b.shape)
51
+
52
+ if ctx.needs_input_grad[2]:
53
+ dc = _unbroadcast(dout, c_shape)
54
+
55
+ return da, db, dc
56
+
57
+ #----------------------------------------------------------------------------
58
+
59
+ def _unbroadcast(x, shape):
60
+ extra_dims = x.ndim - len(shape)
61
+ assert extra_dims >= 0
62
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
63
+ if len(dim):
64
+ x = x.sum(dim=dim, keepdim=True)
65
+ if extra_dims:
66
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
67
+ assert x.shape == shape
68
+ return x
69
+
70
+ #----------------------------------------------------------------------------
71
+
72
+ # pylint: enable=line-too-long
73
+ # pylint: enable=missing-function-docstring
third_party/stylegan2_official_ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Custom replacement for `torch.nn.functional.grid_sample`.
12
+
13
+ This is useful for differentiable augmentation. This customized operator
14
+ supports arbitrarily high order gradients between the input and output. Only
15
+ works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and
16
+ `align_corners=False`.
17
+
18
+ Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
19
+ """
20
+
21
+ # pylint: disable=redefined-builtin
22
+ # pylint: disable=arguments-differ
23
+ # pylint: disable=protected-access
24
+ # pylint: disable=line-too-long
25
+ # pylint: disable=missing-function-docstring
26
+
27
+ import warnings
28
+ import torch
29
+
30
+ #----------------------------------------------------------------------------
31
+
32
+ enabled = True # Enable the custom op by setting this to true.
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def grid_sample(input, grid, impl='cuda'):
37
+ if impl == 'cuda' and _should_use_custom_op():
38
+ return _GridSample2dForward.apply(input, grid)
39
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
40
+
41
+ #----------------------------------------------------------------------------
42
+
43
+ def _should_use_custom_op():
44
+ if not enabled:
45
+ return False
46
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
47
+ return True
48
+ warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
49
+ return False
50
+
51
+ #----------------------------------------------------------------------------
52
+
53
+ class _GridSample2dForward(torch.autograd.Function):
54
+ @staticmethod
55
+ def forward(ctx, input, grid):
56
+ assert input.ndim == 4
57
+ assert grid.ndim == 4
58
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
59
+ ctx.save_for_backward(input, grid)
60
+ return output
61
+
62
+ @staticmethod
63
+ def backward(ctx, grad_output):
64
+ input, grid = ctx.saved_tensors
65
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
66
+ return grad_input, grad_grid
67
+
68
+ #----------------------------------------------------------------------------
69
+
70
+ class _GridSample2dBackward(torch.autograd.Function):
71
+ @staticmethod
72
+ def forward(ctx, grad_output, input, grid):
73
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
74
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
75
+ ctx.save_for_backward(grid)
76
+ return grad_input, grad_grid
77
+
78
+ @staticmethod
79
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
80
+ _ = grad2_grad_grid # unused
81
+ grid, = ctx.saved_tensors
82
+ grad2_grad_output = None
83
+ grad2_input = None
84
+ grad2_grid = None
85
+
86
+ if ctx.needs_input_grad[0]:
87
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
88
+
89
+ assert not ctx.needs_input_grad[2]
90
+ return grad2_grad_output, grad2_input, grad2_grid
91
+
92
+ #----------------------------------------------------------------------------
93
+
94
+ # pylint: enable=redefined-builtin
95
+ # pylint: enable=arguments-differ
96
+ # pylint: enable=protected-access
97
+ # pylint: enable=line-too-long
98
+ # pylint: enable=missing-function-docstring
third_party/stylegan2_official_ops/misc.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Misc functions for customized operations.
12
+
13
+ Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
14
+ """
15
+
16
+ # pylint: disable=line-too-long
17
+ # pylint: disable=missing-class-docstring
18
+ # pylint: disable=missing-function-docstring
19
+ # pylint: disable=use-maxsplit-arg
20
+ # pylint: disable=unnecessary-comprehension
21
+
22
+ import re
23
+ import contextlib
24
+ import warnings
25
+ from easydict import EasyDict
26
+ import numpy as np
27
+ import torch
28
+
29
+ #----------------------------------------------------------------------------
30
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
31
+ # same constant is used multiple times.
32
+
33
+ _constant_cache = dict()
34
+
35
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
36
+ value = np.asarray(value)
37
+ if shape is not None:
38
+ shape = tuple(shape)
39
+ if dtype is None:
40
+ dtype = torch.get_default_dtype()
41
+ if device is None:
42
+ device = torch.device('cpu')
43
+ if memory_format is None:
44
+ memory_format = torch.contiguous_format
45
+
46
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
47
+ tensor = _constant_cache.get(key, None)
48
+ if tensor is None:
49
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
50
+ if shape is not None:
51
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
52
+ tensor = tensor.contiguous(memory_format=memory_format)
53
+ _constant_cache[key] = tensor
54
+ return tensor
55
+
56
+ #----------------------------------------------------------------------------
57
+ # Replace NaN/Inf with specified numerical values.
58
+
59
+ try:
60
+ nan_to_num = torch.nan_to_num # 1.8.0a0
61
+ except AttributeError:
62
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
63
+ assert isinstance(input, torch.Tensor)
64
+ if posinf is None:
65
+ posinf = torch.finfo(input.dtype).max
66
+ if neginf is None:
67
+ neginf = torch.finfo(input.dtype).min
68
+ assert nan == 0
69
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
70
+
71
+ #----------------------------------------------------------------------------
72
+ # Symbolic assert.
73
+
74
+ try:
75
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
76
+ except AttributeError:
77
+ symbolic_assert = torch.Assert # 1.7.0
78
+
79
+ #----------------------------------------------------------------------------
80
+ # Context manager to suppress known warnings in torch.jit.trace().
81
+
82
+ class suppress_tracer_warnings(warnings.catch_warnings):
83
+ def __enter__(self):
84
+ super().__enter__()
85
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
86
+ return self
87
+
88
+ #----------------------------------------------------------------------------
89
+ # Assert that the shape of a tensor matches the given list of integers.
90
+ # None indicates that the size of a dimension is allowed to vary.
91
+ # Performs symbolic assertion when used in torch.jit.trace().
92
+
93
+ def assert_shape(tensor, ref_shape):
94
+ if tensor.ndim != len(ref_shape):
95
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
96
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
97
+ if ref_size is None:
98
+ pass
99
+ elif isinstance(ref_size, torch.Tensor):
100
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
101
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
102
+ elif isinstance(size, torch.Tensor):
103
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
104
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
105
+ elif size != ref_size:
106
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
107
+
108
+ #----------------------------------------------------------------------------
109
+ # Function decorator that calls torch.autograd.profiler.record_function().
110
+
111
+ def profiled_function(fn):
112
+ def decorator(*args, **kwargs):
113
+ with torch.autograd.profiler.record_function(fn.__name__):
114
+ return fn(*args, **kwargs)
115
+ decorator.__name__ = fn.__name__
116
+ return decorator
117
+
118
+ #----------------------------------------------------------------------------
119
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
120
+ # indefinitely, shuffling items as it goes.
121
+
122
+ class InfiniteSampler(torch.utils.data.Sampler):
123
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
124
+ assert len(dataset) > 0
125
+ assert num_replicas > 0
126
+ assert 0 <= rank < num_replicas
127
+ assert 0 <= window_size <= 1
128
+ super().__init__(dataset)
129
+ self.dataset = dataset
130
+ self.rank = rank
131
+ self.num_replicas = num_replicas
132
+ self.shuffle = shuffle
133
+ self.seed = seed
134
+ self.window_size = window_size
135
+
136
+ def __iter__(self):
137
+ order = np.arange(len(self.dataset))
138
+ rnd = None
139
+ window = 0
140
+ if self.shuffle:
141
+ rnd = np.random.RandomState(self.seed)
142
+ rnd.shuffle(order)
143
+ window = int(np.rint(order.size * self.window_size))
144
+
145
+ idx = 0
146
+ while True:
147
+ i = idx % order.size
148
+ if idx % self.num_replicas == self.rank:
149
+ yield order[i]
150
+ if window >= 2:
151
+ j = (i - rnd.randint(window)) % order.size
152
+ order[i], order[j] = order[j], order[i]
153
+ idx += 1
154
+
155
+ #----------------------------------------------------------------------------
156
+ # Utilities for operating with torch.nn.Module parameters and buffers.
157
+
158
+ def params_and_buffers(module):
159
+ assert isinstance(module, torch.nn.Module)
160
+ return list(module.parameters()) + list(module.buffers())
161
+
162
+ def named_params_and_buffers(module):
163
+ assert isinstance(module, torch.nn.Module)
164
+ return list(module.named_parameters()) + list(module.named_buffers())
165
+
166
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
167
+ assert isinstance(src_module, torch.nn.Module)
168
+ assert isinstance(dst_module, torch.nn.Module)
169
+ src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
170
+ for name, tensor in named_params_and_buffers(dst_module):
171
+ assert (name in src_tensors) or (not require_all)
172
+ if name in src_tensors:
173
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
174
+
175
+ #----------------------------------------------------------------------------
176
+ # Context manager for easily enabling/disabling DistributedDataParallel
177
+ # synchronization.
178
+
179
+ @contextlib.contextmanager
180
+ def ddp_sync(module, sync):
181
+ assert isinstance(module, torch.nn.Module)
182
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
183
+ yield
184
+ else:
185
+ with module.no_sync():
186
+ yield
187
+
188
+ #----------------------------------------------------------------------------
189
+ # Check DistributedDataParallel consistency across processes.
190
+
191
+ def check_ddp_consistency(module, ignore_regex=None):
192
+ assert isinstance(module, torch.nn.Module)
193
+ for name, tensor in named_params_and_buffers(module):
194
+ fullname = type(module).__name__ + '.' + name
195
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
196
+ continue
197
+ tensor = tensor.detach()
198
+ other = tensor.clone()
199
+ torch.distributed.broadcast(tensor=other, src=0)
200
+ assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
201
+
202
+ #----------------------------------------------------------------------------
203
+ # Print summary table of module hierarchy.
204
+
205
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
206
+ assert isinstance(module, torch.nn.Module)
207
+ assert not isinstance(module, torch.jit.ScriptModule)
208
+ assert isinstance(inputs, (tuple, list))
209
+
210
+ # Register hooks.
211
+ entries = []
212
+ nesting = [0]
213
+ def pre_hook(_mod, _inputs):
214
+ nesting[0] += 1
215
+ def post_hook(mod, _inputs, outputs):
216
+ nesting[0] -= 1
217
+ if nesting[0] <= max_nesting:
218
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
219
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
220
+ entries.append(EasyDict(mod=mod, outputs=outputs))
221
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
222
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
223
+
224
+ # Run module.
225
+ outputs = module(*inputs)
226
+ for hook in hooks:
227
+ hook.remove()
228
+
229
+ # Identify unique outputs, parameters, and buffers.
230
+ tensors_seen = set()
231
+ for e in entries:
232
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
233
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
234
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
235
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
236
+
237
+ # Filter out redundant entries.
238
+ if skip_redundant:
239
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
240
+
241
+ # Construct table.
242
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
243
+ rows += [['---'] * len(rows[0])]
244
+ param_total = 0
245
+ buffer_total = 0
246
+ submodule_names = {mod: name for name, mod in module.named_modules()}
247
+ for e in entries:
248
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
249
+ param_size = sum(t.numel() for t in e.unique_params)
250
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
251
+ output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
252
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
253
+ rows += [[
254
+ name + (':0' if len(e.outputs) >= 2 else ''),
255
+ str(param_size) if param_size else '-',
256
+ str(buffer_size) if buffer_size else '-',
257
+ (output_shapes + ['-'])[0],
258
+ (output_dtypes + ['-'])[0],
259
+ ]]
260
+ for idx in range(1, len(e.outputs)):
261
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
262
+ param_total += param_size
263
+ buffer_total += buffer_size
264
+ rows += [['---'] * len(rows[0])]
265
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
266
+
267
+ # Print table.
268
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
269
+ print()
270
+ for row in rows:
271
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
272
+ print()
273
+ return outputs
274
+
275
+ #----------------------------------------------------------------------------
276
+
277
+ # pylint: enable=line-too-long
278
+ # pylint: enable=missing-class-docstring
279
+ # pylint: enable=missing-function-docstring
280
+ # pylint: enable=use-maxsplit-arg
281
+ # pylint: enable=unnecessary-comprehension
third_party/stylegan2_official_ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29
+
30
+ // Create output tensor.
31
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37
+
38
+ // Initialize CUDA kernel parameters.
39
+ upfirdn2d_kernel_params p;
40
+ p.x = x.data_ptr();
41
+ p.f = f.data_ptr<float>();
42
+ p.y = y.data_ptr();
43
+ p.up = make_int2(upx, upy);
44
+ p.down = make_int2(downx, downy);
45
+ p.pad0 = make_int2(padx0, pady0);
46
+ p.flip = (flip) ? 1 : 0;
47
+ p.gain = gain;
48
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56
+
57
+ // Choose CUDA kernel.
58
+ upfirdn2d_kernel_spec spec;
59
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60
+ {
61
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
62
+ });
63
+
64
+ // Set looping options.
65
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66
+ p.loopMinor = spec.loopMinor;
67
+ p.loopX = spec.loopX;
68
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70
+
71
+ // Compute grid size.
72
+ dim3 blockSize, gridSize;
73
+ if (spec.tileOutW < 0) // large
74
+ {
75
+ blockSize = dim3(4, 32, 1);
76
+ gridSize = dim3(
77
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79
+ p.launchMajor);
80
+ }
81
+ else // small
82
+ {
83
+ blockSize = dim3(256, 1, 1);
84
+ gridSize = dim3(
85
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87
+ p.launchMajor);
88
+ }
89
+
90
+ // Launch CUDA kernel.
91
+ void* args[] = {&p};
92
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93
+ return y;
94
+ }
95
+
96
+ //------------------------------------------------------------------------
97
+
98
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99
+ {
100
+ m.def("upfirdn2d", &upfirdn2d);
101
+ }
102
+
103
+ //------------------------------------------------------------------------
third_party/stylegan2_official_ops/upfirdn2d.cu ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "upfirdn2d.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ static __device__ __forceinline__ int floor_div(int a, int b)
21
+ {
22
+ int t = 1 - a / b;
23
+ return (a + t * b) / b - t;
24
+ }
25
+
26
+ //------------------------------------------------------------------------
27
+ // Generic CUDA implementation for large filters.
28
+
29
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
+ {
31
+ typedef typename InternalType<T>::scalar_t scalar_t;
32
+
33
+ // Calculate thread index.
34
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
+ int outY = minorBase / p.launchMinor;
36
+ minorBase -= outY * p.launchMinor;
37
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
+ int majorBase = blockIdx.z * p.loopMajor;
39
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
+ return;
41
+
42
+ // Setup Y receptive field.
43
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
+ if (p.flip)
48
+ filterY = p.filterSize.y - 1 - filterY;
49
+
50
+ // Loop over major, minor, and X.
51
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
+ {
54
+ int nc = major * p.sizeMinor + minor;
55
+ int n = nc / p.inSize.z;
56
+ int c = nc - n * p.inSize.z;
57
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
+ {
59
+ // Setup X receptive field.
60
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
+ if (p.flip)
65
+ filterX = p.filterSize.x - 1 - filterX;
66
+
67
+ // Initialize pointers.
68
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
+
73
+ // Inner loop.
74
+ scalar_t v = 0;
75
+ for (int y = 0; y < h; y++)
76
+ {
77
+ for (int x = 0; x < w; x++)
78
+ {
79
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
+ xp += p.inStride.x;
81
+ fp += filterStepX;
82
+ }
83
+ xp += p.inStride.y - w * p.inStride.x;
84
+ fp += filterStepY - w * filterStepX;
85
+ }
86
+
87
+ // Store result.
88
+ v *= p.gain;
89
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
+ }
91
+ }
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+ // Specialized CUDA implementation for small filters.
96
+
97
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
+ {
100
+ typedef typename InternalType<T>::scalar_t scalar_t;
101
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
+ __shared__ volatile scalar_t sf[filterH][filterW];
104
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
+
106
+ // Calculate tile index.
107
+ int minorBase = blockIdx.x;
108
+ int tileOutY = minorBase / p.launchMinor;
109
+ minorBase -= tileOutY * p.launchMinor;
110
+ minorBase *= loopMinor;
111
+ tileOutY *= tileOutH;
112
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
+ int majorBase = blockIdx.z * p.loopMajor;
114
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
+ return;
116
+
117
+ // Load filter (flipped).
118
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
+ {
120
+ int fy = tapIdx / filterW;
121
+ int fx = tapIdx - fy * filterW;
122
+ scalar_t v = 0;
123
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
124
+ {
125
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
+ }
129
+ sf[fy][fx] = v;
130
+ }
131
+
132
+ // Loop over major and X.
133
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
+ {
135
+ int baseNC = major * p.sizeMinor + minorBase;
136
+ int n = baseNC / p.inSize.z;
137
+ int baseC = baseNC - n * p.inSize.z;
138
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
+ {
140
+ // Load input pixels.
141
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
+ int tileInX = floor_div(tileMidX, upx);
144
+ int tileInY = floor_div(tileMidY, upy);
145
+ __syncthreads();
146
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
+ {
148
+ int relC = inIdx;
149
+ int relInX = relC / loopMinor;
150
+ int relInY = relInX / tileInW;
151
+ relC -= relInX * loopMinor;
152
+ relInX -= relInY * tileInW;
153
+ int c = baseC + relC;
154
+ int inX = tileInX + relInX;
155
+ int inY = tileInY + relInY;
156
+ scalar_t v = 0;
157
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
+ sx[relInY][relInX][relC] = v;
160
+ }
161
+
162
+ // Loop over output pixels.
163
+ __syncthreads();
164
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
+ {
166
+ int relC = outIdx;
167
+ int relOutX = relC / loopMinor;
168
+ int relOutY = relOutX / tileOutW;
169
+ relC -= relOutX * loopMinor;
170
+ relOutX -= relOutY * tileOutW;
171
+ int c = baseC + relC;
172
+ int outX = tileOutX + relOutX;
173
+ int outY = tileOutY + relOutY;
174
+
175
+ // Setup receptive field.
176
+ int midX = tileMidX + relOutX * downx;
177
+ int midY = tileMidY + relOutY * downy;
178
+ int inX = floor_div(midX, upx);
179
+ int inY = floor_div(midY, upy);
180
+ int relInX = inX - tileInX;
181
+ int relInY = inY - tileInY;
182
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
183
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
184
+
185
+ // Inner loop.
186
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
+ {
188
+ scalar_t v = 0;
189
+ #pragma unroll
190
+ for (int y = 0; y < filterH / upy; y++)
191
+ #pragma unroll
192
+ for (int x = 0; x < filterW / upx; x++)
193
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
+ v *= p.gain;
195
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
+ }
197
+ }
198
+ }
199
+ }
200
+ }
201
+
202
+ //------------------------------------------------------------------------
203
+ // CUDA kernel selection.
204
+
205
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
+ {
207
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
+
209
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
210
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
211
+
212
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
213
+ {
214
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
215
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
216
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
217
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
218
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
219
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
220
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
221
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
222
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
223
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
224
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
225
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
226
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
228
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
229
+ }
230
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
231
+ {
232
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
233
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
234
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
236
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
237
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
238
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
239
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
240
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
241
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
242
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
243
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
244
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
245
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
246
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
247
+ }
248
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
249
+ {
250
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
+ }
255
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
256
+ {
257
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
+ }
262
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
263
+ {
264
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
265
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
266
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
268
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
269
+ }
270
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
271
+ {
272
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
273
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
274
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
275
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
276
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
277
+ }
278
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
279
+ {
280
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
281
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
282
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
283
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
284
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
285
+ }
286
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
287
+ {
288
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
289
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
290
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
291
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
292
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
293
+ }
294
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
295
+ {
296
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
297
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
298
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
299
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
300
+ }
301
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
302
+ {
303
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
304
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
305
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
306
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
307
+ }
308
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
309
+ {
310
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
311
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
312
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
313
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
314
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
315
+ }
316
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
317
+ {
318
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
319
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
320
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
321
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
322
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
323
+ }
324
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
325
+ {
326
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
327
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
328
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
329
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
330
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
331
+ }
332
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
333
+ {
334
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
335
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
336
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
337
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
338
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
339
+ }
340
+ return spec;
341
+ }
342
+
343
+ //------------------------------------------------------------------------
344
+ // Template specializations.
345
+
346
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
347
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
348
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
349
+
350
+ //------------------------------------------------------------------------
third_party/stylegan2_official_ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct upfirdn2d_kernel_params
15
+ {
16
+ const void* x;
17
+ const float* f;
18
+ void* y;
19
+
20
+ int2 up;
21
+ int2 down;
22
+ int2 pad0;
23
+ int flip;
24
+ float gain;
25
+
26
+ int4 inSize; // [width, height, channel, batch]
27
+ int4 inStride;
28
+ int2 filterSize; // [width, height]
29
+ int2 filterStride;
30
+ int4 outSize; // [width, height, channel, batch]
31
+ int4 outStride;
32
+ int sizeMinor;
33
+ int sizeMajor;
34
+
35
+ int loopMinor;
36
+ int loopMajor;
37
+ int loopX;
38
+ int launchMinor;
39
+ int launchMajor;
40
+ };
41
+
42
+ //------------------------------------------------------------------------
43
+ // CUDA kernel specialization.
44
+
45
+ struct upfirdn2d_kernel_spec
46
+ {
47
+ void* kernel;
48
+ int tileOutW;
49
+ int tileOutH;
50
+ int loopMinor;
51
+ int loopX;
52
+ };
53
+
54
+ //------------------------------------------------------------------------
55
+ // CUDA kernel selection.
56
+
57
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
+
59
+ //------------------------------------------------------------------------
third_party/stylegan2_official_ops/upfirdn2d.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Custom operators for efficient resampling of 2D images.
12
+
13
+ `upfirdn` means executing upsampling, FIR filtering, downsampling in sequence.
14
+
15
+ Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
16
+ """
17
+
18
+ # pylint: disable=line-too-long
19
+ # pylint: disable=missing-class-docstring
20
+ # pylint: disable=global-variable-not-assigned
21
+ # pylint: disable=bare-except
22
+
23
+ import os
24
+ import warnings
25
+ import traceback
26
+ import numpy as np
27
+ import torch
28
+
29
+ from . import custom_ops
30
+ from . import misc
31
+ from . import conv2d_gradfix
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ _inited = False
36
+ _plugin = None
37
+
38
+ def _init():
39
+ global _inited, _plugin
40
+ if not _inited:
41
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
42
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
43
+ try:
44
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
45
+ except:
46
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
47
+ return _plugin is not None
48
+
49
+ def _parse_scaling(scaling):
50
+ if isinstance(scaling, int):
51
+ scaling = [scaling, scaling]
52
+ assert isinstance(scaling, (list, tuple))
53
+ assert all(isinstance(x, int) for x in scaling)
54
+ sx, sy = scaling
55
+ assert sx >= 1 and sy >= 1
56
+ return sx, sy
57
+
58
+ def _parse_padding(padding):
59
+ if isinstance(padding, int):
60
+ padding = [padding, padding]
61
+ assert isinstance(padding, (list, tuple))
62
+ assert all(isinstance(x, int) for x in padding)
63
+ if len(padding) == 2:
64
+ padx, pady = padding
65
+ padding = [padx, padx, pady, pady]
66
+ padx0, padx1, pady0, pady1 = padding
67
+ return padx0, padx1, pady0, pady1
68
+
69
+ def _get_filter_size(f):
70
+ if f is None:
71
+ return 1, 1
72
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
73
+ fw = f.shape[-1]
74
+ fh = f.shape[0]
75
+ with misc.suppress_tracer_warnings():
76
+ fw = int(fw)
77
+ fh = int(fh)
78
+ misc.assert_shape(f, [fh, fw][:f.ndim])
79
+ assert fw >= 1 and fh >= 1
80
+ return fw, fh
81
+
82
+ #----------------------------------------------------------------------------
83
+
84
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
85
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
86
+
87
+ Args:
88
+ f: Torch tensor, numpy array, or python list of the shape
89
+ `[filter_height, filter_width]` (non-separable),
90
+ `[filter_taps]` (separable),
91
+ `[]` (impulse), or
92
+ `None` (identity).
93
+ device: Result device (default: cpu).
94
+ normalize: Normalize the filter so that it retains the magnitude
95
+ for constant input signal (DC)? (default: True).
96
+ flip_filter: Flip the filter? (default: False).
97
+ gain: Overall scaling factor for signal magnitude (default: 1).
98
+ separable: Return a separable filter? (default: select automatically).
99
+
100
+ Returns:
101
+ Float32 tensor of the shape
102
+ `[filter_height, filter_width]` (non-separable) or
103
+ `[filter_taps]` (separable).
104
+ """
105
+ # Validate.
106
+ if f is None:
107
+ f = 1
108
+ f = torch.as_tensor(f, dtype=torch.float32)
109
+ assert f.ndim in [0, 1, 2]
110
+ assert f.numel() > 0
111
+ if f.ndim == 0:
112
+ f = f[np.newaxis]
113
+
114
+ # Separable?
115
+ if separable is None:
116
+ separable = (f.ndim == 1 and f.numel() >= 8)
117
+ if f.ndim == 1 and not separable:
118
+ f = f.ger(f)
119
+ assert f.ndim == (1 if separable else 2)
120
+
121
+ # Apply normalize, flip, gain, and device.
122
+ if normalize:
123
+ f /= f.sum()
124
+ if flip_filter:
125
+ f = f.flip(list(range(f.ndim)))
126
+ f = f * (gain ** (f.ndim / 2))
127
+ f = f.to(device=device)
128
+ return f
129
+
130
+ #----------------------------------------------------------------------------
131
+
132
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
133
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
134
+
135
+ Performs the following sequence of operations for each channel:
136
+
137
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
138
+
139
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
140
+ Negative padding corresponds to cropping the image.
141
+
142
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
143
+ so that the footprint of all output pixels lies within the input image.
144
+
145
+ 4. Downsample the image by keeping every Nth pixel (`down`).
146
+
147
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
148
+ The fused op is considerably more efficient than performing the same calculation
149
+ using standard PyTorch ops. It supports gradients of arbitrary order.
150
+
151
+ Args:
152
+ x: Float32/float64/float16 input tensor of the shape
153
+ `[batch_size, num_channels, in_height, in_width]`.
154
+ f: Float32 FIR filter of the shape
155
+ `[filter_height, filter_width]` (non-separable),
156
+ `[filter_taps]` (separable), or
157
+ `None` (identity).
158
+ up: Integer upsampling factor. Can be a single int or a list/tuple
159
+ `[x, y]` (default: 1).
160
+ down: Integer downsampling factor. Can be a single int or a list/tuple
161
+ `[x, y]` (default: 1).
162
+ padding: Padding with respect to the upsampled image. Can be a single number
163
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
164
+ (default: 0).
165
+ flip_filter: False = convolution, True = correlation (default: False).
166
+ gain: Overall scaling factor for signal magnitude (default: 1).
167
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
168
+
169
+ Returns:
170
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
171
+ """
172
+ assert isinstance(x, torch.Tensor)
173
+ assert impl in ['ref', 'cuda']
174
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
175
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
176
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
177
+
178
+ #----------------------------------------------------------------------------
179
+
180
+ @misc.profiled_function
181
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
182
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
183
+ """
184
+ # Validate arguments.
185
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
186
+ if f is None:
187
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
188
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
189
+ assert f.dtype == torch.float32 and not f.requires_grad
190
+ batch_size, num_channels, in_height, in_width = x.shape
191
+ upx, upy = _parse_scaling(up)
192
+ downx, downy = _parse_scaling(down)
193
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
194
+
195
+ # Upsample by inserting zeros.
196
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
197
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
198
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
199
+
200
+ # Pad or crop.
201
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
202
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
203
+
204
+ # Setup filter.
205
+ f = f * (gain ** (f.ndim / 2))
206
+ f = f.to(x.dtype)
207
+ if not flip_filter:
208
+ f = f.flip(list(range(f.ndim)))
209
+
210
+ # Convolve with the filter.
211
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
212
+ if f.ndim == 4:
213
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
214
+ else:
215
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
216
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
217
+
218
+ # Downsample by throwing away pixels.
219
+ x = x[:, :, ::downy, ::downx]
220
+ return x
221
+
222
+ #----------------------------------------------------------------------------
223
+
224
+ _upfirdn2d_cuda_cache = dict()
225
+
226
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
227
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
228
+ """
229
+ # Parse arguments.
230
+ upx, upy = _parse_scaling(up)
231
+ downx, downy = _parse_scaling(down)
232
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
233
+
234
+ # Lookup from cache.
235
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
236
+ if key in _upfirdn2d_cuda_cache:
237
+ return _upfirdn2d_cuda_cache[key]
238
+
239
+ # Forward op.
240
+ class Upfirdn2dCuda(torch.autograd.Function):
241
+ @staticmethod
242
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
243
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
244
+ if f is None:
245
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
246
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
247
+ y = x
248
+ if f.ndim == 2:
249
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
250
+ else:
251
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
252
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
253
+ ctx.save_for_backward(f)
254
+ ctx.x_shape = x.shape
255
+ return y
256
+
257
+ @staticmethod
258
+ def backward(ctx, dy): # pylint: disable=arguments-differ
259
+ f, = ctx.saved_tensors
260
+ _, _, ih, iw = ctx.x_shape
261
+ _, _, oh, ow = dy.shape
262
+ fw, fh = _get_filter_size(f)
263
+ p = [
264
+ fw - padx0 - 1,
265
+ iw * upx - ow * downx + padx0 - upx + 1,
266
+ fh - pady0 - 1,
267
+ ih * upy - oh * downy + pady0 - upy + 1,
268
+ ]
269
+ dx = None
270
+ df = None
271
+
272
+ if ctx.needs_input_grad[0]:
273
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
274
+
275
+ assert not ctx.needs_input_grad[1]
276
+ return dx, df
277
+
278
+ # Add to cache.
279
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
280
+ return Upfirdn2dCuda
281
+
282
+ #----------------------------------------------------------------------------
283
+
284
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
285
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
286
+
287
+ By default, the result is padded so that its shape matches the input.
288
+ User-specified padding is applied on top of that, with negative values
289
+ indicating cropping. Pixels outside the image are assumed to be zero.
290
+
291
+ Args:
292
+ x: Float32/float64/float16 input tensor of the shape
293
+ `[batch_size, num_channels, in_height, in_width]`.
294
+ f: Float32 FIR filter of the shape
295
+ `[filter_height, filter_width]` (non-separable),
296
+ `[filter_taps]` (separable), or
297
+ `None` (identity).
298
+ padding: Padding with respect to the output. Can be a single number or a
299
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
300
+ (default: 0).
301
+ flip_filter: False = convolution, True = correlation (default: False).
302
+ gain: Overall scaling factor for signal magnitude (default: 1).
303
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
304
+
305
+ Returns:
306
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
307
+ """
308
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
309
+ fw, fh = _get_filter_size(f)
310
+ p = [
311
+ padx0 + fw // 2,
312
+ padx1 + (fw - 1) // 2,
313
+ pady0 + fh // 2,
314
+ pady1 + (fh - 1) // 2,
315
+ ]
316
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
317
+
318
+ #----------------------------------------------------------------------------
319
+
320
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
321
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
322
+
323
+ By default, the result is padded so that its shape is a multiple of the input.
324
+ User-specified padding is applied on top of that, with negative values
325
+ indicating cropping. Pixels outside the image are assumed to be zero.
326
+
327
+ Args:
328
+ x: Float32/float64/float16 input tensor of the shape
329
+ `[batch_size, num_channels, in_height, in_width]`.
330
+ f: Float32 FIR filter of the shape
331
+ `[filter_height, filter_width]` (non-separable),
332
+ `[filter_taps]` (separable), or
333
+ `None` (identity).
334
+ up: Integer upsampling factor. Can be a single int or a list/tuple
335
+ `[x, y]` (default: 1).
336
+ padding: Padding with respect to the output. Can be a single number or a
337
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
338
+ (default: 0).
339
+ flip_filter: False = convolution, True = correlation (default: False).
340
+ gain: Overall scaling factor for signal magnitude (default: 1).
341
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
342
+
343
+ Returns:
344
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
345
+ """
346
+ upx, upy = _parse_scaling(up)
347
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
348
+ fw, fh = _get_filter_size(f)
349
+ p = [
350
+ padx0 + (fw + upx - 1) // 2,
351
+ padx1 + (fw - upx) // 2,
352
+ pady0 + (fh + upy - 1) // 2,
353
+ pady1 + (fh - upy) // 2,
354
+ ]
355
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
356
+
357
+ #----------------------------------------------------------------------------
358
+
359
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
360
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
361
+
362
+ By default, the result is padded so that its shape is a fraction of the input.
363
+ User-specified padding is applied on top of that, with negative values
364
+ indicating cropping. Pixels outside the image are assumed to be zero.
365
+
366
+ Args:
367
+ x: Float32/float64/float16 input tensor of the shape
368
+ `[batch_size, num_channels, in_height, in_width]`.
369
+ f: Float32 FIR filter of the shape
370
+ `[filter_height, filter_width]` (non-separable),
371
+ `[filter_taps]` (separable), or
372
+ `None` (identity).
373
+ down: Integer downsampling factor. Can be a single int or a list/tuple
374
+ `[x, y]` (default: 1).
375
+ padding: Padding with respect to the input. Can be a single number or a
376
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
377
+ (default: 0).
378
+ flip_filter: False = convolution, True = correlation (default: False).
379
+ gain: Overall scaling factor for signal magnitude (default: 1).
380
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
381
+
382
+ Returns:
383
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
384
+ """
385
+ downx, downy = _parse_scaling(down)
386
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
387
+ fw, fh = _get_filter_size(f)
388
+ p = [
389
+ padx0 + (fw - downx + 1) // 2,
390
+ padx1 + (fw - downx) // 2,
391
+ pady0 + (fh - downy + 1) // 2,
392
+ pady1 + (fh - downy) // 2,
393
+ ]
394
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
395
+
396
+ #----------------------------------------------------------------------------
397
+
398
+ # pylint: enable=line-too-long
399
+ # pylint: enable=missing-class-docstring
400
+ # pylint: enable=global-variable-not-assigned
401
+ # pylint: enable=bare-except
third_party/stylegan3_official_ops/README.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Operators for StyleGAN2
2
+
3
+ All files in this directory are borrowed from repository [stylegan3](https://github.com/NVlabs/stylegan3). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including
4
+
5
+ - `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator.
6
+ - `upfirdn2d.setup_filter()`: Set up the kernel used for filtering.
7
+ - `upfirdn2d.filter2d()`: Filtering a 2D feature map with given kernel.
8
+ - `upfirdn2d.upsample2d()`: Upsampling a 2D feature map.
9
+ - `upfirdn2d.downsample2d()`: Downsampling a 2D feature map.
10
+ - `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map.
11
+ - `filtered_lrelu.filtered_lrelu()`: Leaky ReLU layer, wrapped with upsampling and downsampling for anti-aliasing.
12
+ - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
13
+ - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
14
+ - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`)
15
+
16
+ We make following slight modifications beyond disabling some lint warnings:
17
+
18
+ - Line 24 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3).
19
+ - Line 36 of file `custom_ops.py`: Disable log message when setting up customized operators.
20
+ - Line 54/109 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your cuda binary does not locate at `/usr/local/cuda/bin`, please specify in function `_find_compiler_bindir_posix()`.*)
21
+ - Line 21 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3).
22
+ - Line 162-165 of file `filtered_lrelu.py`: Change some implementations in `_filtered_lrelu_ref()` to `ref`.
23
+ - Line 31 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default.
24
+ - Line 35 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator.
25
+ - Line 34 of file `conv2d_gradfix.py`: Enable customized convolution operators by default.
26
+ - Line 48/53 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators.
27
+ - Line 36/53 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators.
28
+ - Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator.
29
+
30
+ Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default.