add files
This view is limited to 50 files because it contains too many changes.
- compute_direction.py +96 -0
- compute_jacobian.py +200 -0
- coordinate.py +142 -0
- directions/.DS_Store +0 -0
- directions/afhq/.DS_Store +0 -0
- directions/afhq/stylegan3/eyes-r.npy +3 -0
- directions/ffhq/stylegan2/eyebrows.npy +3 -0
- directions/ffhq/stylegan2/eyesize.npy +3 -0
- directions/ffhq/stylegan2/gaze_direction.npy +3 -0
- directions/ffhq/stylegan2/lipstick.npy +3 -0
- directions/ffhq/stylegan2/mouth.npy +3 -0
- directions/ffhq/stylegan2/nose_length.npy +3 -0
- directions/ffhq/stylegan3/eyes-r.npy +3 -0
- manipulate.py +253 -0
- models/__init__.py +45 -0
- models/ghfeat_encoder.py +563 -0
- models/inception_model.py +562 -0
- models/perceptual_model.py +519 -0
- models/pggan_discriminator.py +465 -0
- models/pggan_generator.py +401 -0
- models/stylegan2_discriminator.py +729 -0
- models/stylegan2_generator.py +1394 -0
- models/stylegan3_generator.py +1332 -0
- models/stylegan_discriminator.py +624 -0
- models/stylegan_generator.py +999 -0
- models/test.py +146 -0
- models/utils/__init__.py +0 -0
- models/utils/ops.py +18 -0
- requirements/convert.txt +11 -0
- requirements/develop.txt +3 -0
- requirements/minimal.txt +21 -0
- synthesis.py +178 -0
- third_party/__init__.py +0 -0
- third_party/stylegan2_official_ops/README.md +28 -0
- third_party/stylegan2_official_ops/__init__.py +0 -0
- third_party/stylegan2_official_ops/bias_act.cpp +99 -0
- third_party/stylegan2_official_ops/bias_act.cu +173 -0
- third_party/stylegan2_official_ops/bias_act.h +38 -0
- third_party/stylegan2_official_ops/bias_act.py +227 -0
- third_party/stylegan2_official_ops/conv2d_gradfix.py +189 -0
- third_party/stylegan2_official_ops/conv2d_resample.py +168 -0
- third_party/stylegan2_official_ops/custom_ops.py +159 -0
- third_party/stylegan2_official_ops/fma.py +73 -0
- third_party/stylegan2_official_ops/grid_sample_gradfix.py +98 -0
- third_party/stylegan2_official_ops/misc.py +281 -0
- third_party/stylegan2_official_ops/upfirdn2d.cpp +103 -0
- third_party/stylegan2_official_ops/upfirdn2d.cu +350 -0
- third_party/stylegan2_official_ops/upfirdn2d.h +59 -0
- third_party/stylegan2_official_ops/upfirdn2d.py +401 -0
- third_party/stylegan3_official_ops/README.md +30 -0
compute_direction.py
ADDED
@@ -0,0 +1,96 @@
# python3.7
"""Computes the semantic directions regarding a specific image region."""

import os
import argparse
import numpy as np
from tqdm import tqdm

from coordinate import COORDINATES
from coordinate import get_mask
from utils.image_utils import save_image


def parse_args():
    """Parses arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('jaco_path', type=str,
                        help='Path to the Jacobian matrix.')
    parser.add_argument('--region', type=str, default='eyes',
                        help='The image region used to compute the '
                             'directions.')
    parser.add_argument('--save_dir', type=str, default='',
                        help='Directory to save the results. If not '
                             'specified, the results will be saved to '
                             '`work_dirs/{TASK_SPECIFIC}/` by default.')
    parser.add_argument('--job', type=str, default='directions',
                        help='Name for the job. (default: directions)')
    parser.add_argument('--name', type=str, default='resefa',
                        help='Name used when saving the results.')
    parser.add_argument('--data_name', type=str, default='ffhq',
                        help='Name of the dataset.')
    parser.add_argument('--full_rank', action='store_true',
                        help='Whether to make the background matrix full '
                             'rank. (default: False)')
    parser.add_argument('--tao', type=float, default=1e-3,
                        help='Coefficient applied to the identity matrix. '
                             '(default: 1e-3)')
    return parser.parse_args()


def main():
    """Main function."""
    args = parse_args()
    assert os.path.exists(args.jaco_path)
    Jacobians = np.load(args.jaco_path)
    image_size = Jacobians.shape[2]
    w_dim = Jacobians.shape[-1]
    coord_dict = COORDINATES[args.data_name]
    assert args.region in coord_dict, \
        f'{args.region} coordinate is not defined in ' \
        f'COORDINATE_{args.data_name}. Please define this region first!'
    coords = coord_dict[args.region]
    mask = get_mask(image_size, coordinate=coords)
    foreground_ind = np.where(mask == 1)
    background_ind = np.where((1 - mask) == 1)
    temp_dir = f'./work_dirs/{args.job}/{args.data_name}/{args.region}'
    save_dir = args.save_dir or temp_dir
    os.makedirs(save_dir, exist_ok=True)
    for ind in tqdm(range(Jacobians.shape[0])):
        Jacobian = Jacobians[ind]
        if len(Jacobian.shape) == 4:  # [H, W, 1, latent_dim]
            Jaco_fore = Jacobian[foreground_ind[0], foreground_ind[1], 0]
            Jaco_back = Jacobian[background_ind[0], background_ind[1], 0]
        elif len(Jacobian.shape) == 5:  # [channel, H, W, 1, latent_dim]
            Jaco_fore = Jacobian[:, foreground_ind[0], foreground_ind[1], 0]
            Jaco_back = Jacobian[:, background_ind[0], background_ind[1], 0]
        else:
            raise ValueError('Shape of the Jacobian is not correct!')
        Jaco_fore = np.reshape(Jaco_fore, [-1, w_dim])
        Jaco_back = np.reshape(Jaco_back, [-1, w_dim])
        coef_f = 1 / Jaco_fore.shape[0]
        coef_b = 1 / Jaco_back.shape[0]
        M_fore = coef_f * Jaco_fore.T.dot(Jaco_fore)
        M_back = coef_b * Jaco_back.T.dot(Jaco_back)
        if args.full_rank:
            # J = J_b^T J_b
            # J = J + tao * trace(J) * I
            print('Using full rank')
            coef = args.tao * np.trace(M_back)
            M_back = M_back + coef * np.identity(M_back.shape[0])
        # Solve inv(M_back) * M_fore * x = lambda * x.
        temp = np.linalg.inv(M_back).dot(M_fore)
        eig_val, eig_vec = np.linalg.eig(temp)
        eig_val = np.real(eig_val)
        eig_vec = np.real(eig_vec)
        directions = eig_vec.T
        directions = directions[np.argsort(-eig_val)]
        save_name = f'{save_dir}/image_{ind:02d}_region_{args.region}' \
                    f'_name_{args.name}'
        np.save(f'{save_name}.npy', directions)
        mask_i = np.tile(mask[:, :, np.newaxis], [1, 1, 3]) * 255
        save_image(f'{save_name}_mask.png', mask_i.astype(np.uint8))


if __name__ == '__main__':
    main()
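The core of this script is the generalized eigenvalue problem M_fore x = lambda * M_back x, solved above by explicitly inverting the (regularized) background matrix. For reference, a self-contained NumPy sketch of that single step on toy data; all shapes and names below are illustrative, not part of the repository:

import numpy as np

# Toy stand-ins for the foreground/background Jacobian blocks.
w_dim = 8
J_fore = np.random.randn(100, w_dim)
J_back = np.random.randn(400, w_dim)
M_fore = J_fore.T @ J_fore / J_fore.shape[0]
M_back = J_back.T @ J_back / J_back.shape[0]

# Regularize the background matrix so it is invertible (the `--full_rank`
# branch above), then solve inv(M_back) @ M_fore @ x = lambda * x.
M_back += 1e-3 * np.trace(M_back) * np.identity(w_dim)
eig_val, eig_vec = np.linalg.eig(np.linalg.inv(M_back) @ M_fore)
order = np.argsort(-np.real(eig_val))
directions = np.real(eig_vec).T[order]  # rows are directions, strongest first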
compute_jacobian.py
ADDED
@@ -0,0 +1,200 @@
# python3.7
"""Functions to compute the Jacobian based on a pre-trained GAN generator.

Supports StyleGAN2 and StyleGAN3.
"""

import os
import argparse
import warnings
from tqdm import tqdm
import numpy as np

import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian
from models import build_model
from utils.image_utils import save_image
from utils.image_utils import postprocess_image
from utils.custom_utils import to_numpy


warnings.filterwarnings(action='ignore', category=UserWarning)


def parse_args():
    """Parses arguments."""
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('General options.')
    group.add_argument('weight_path', type=str,
                       help='Path to the pre-trained model weights.')
    group.add_argument('--save_dir', type=str, default=None,
                       help='Directory to save the results. If not specified, '
                            'the results will be saved to '
                            '`work_dirs/{TASK_SPECIFIC}/` by default.')
    group.add_argument('--job', type=str, default='jacobians',
                       help='Name for the job. (default: jacobians)')
    group.add_argument('--seed', type=int, default=4,
                       help='Seed for sampling. (default: 4)')
    group.add_argument('--nums', type=int, default=5,
                       help='Number of samples to synthesize. (default: 5)')
    group.add_argument('--img_size', type=int, default=1024,
                       help='Size of the synthesized images. (default: 1024)')
    group.add_argument('--w_dim', type=int, default=512,
                       help='Dimension of the latent w. (default: 512)')
    group.add_argument('--save_jpg', action='store_false',
                       help='Whether to save the images used to compute the '
                            'Jacobians. (default: True)')
    group.add_argument('-d', '--data_name', type=str, default='ffhq',
                       help='Name of the dataset. (default: ffhq)')
    group.add_argument('--latent_path', type=str, default='',
                       help='Path to the given latent codes. (default: None)')

    group = parser.add_argument_group('StyleGAN2')
    group.add_argument('--stylegan2', action='store_true',
                       help='Whether to use StyleGAN2. (default: False)')
    group.add_argument('--scale_stylegan2', type=float, default=1.0,
                       help='Scale for the number of channels for StyleGAN2.')
    group.add_argument('--randomize_noise', type=str, default='const',
                       help='Noise type when computing. (const or random)')

    group = parser.add_argument_group('StyleGAN3')
    group.add_argument('--stylegan3', action='store_true',
                       help='Whether to use StyleGAN3. (default: False)')
    group.add_argument('--cfg', type=str, default='T',
                       help='Config of StyleGAN3 (T/R).')
    group.add_argument('--scale_stylegan3r', type=float, default=2.0,
                       help='Scale for the number of channels for StyleGAN3-R.')
    group.add_argument('--scale_stylegan3t', type=float, default=1.0,
                       help='Scale for the number of channels for StyleGAN3-T.')
    group.add_argument('--tx', type=float, default=0,
                       help='Translation along the X-coordinate. '
                            '(default: 0.0)')
    group.add_argument('--ty', type=float, default=0,
                       help='Translation along the Y-coordinate. '
                            '(default: 0.0)')
    group.add_argument('--rotate', type=float, default=0,
                       help='Rotation angle in degrees. (default: 0)')

    group = parser.add_argument_group('Jacobians')
    group.add_argument('--b', type=float, default=1e-3,
                       help='Constant used when computing Jacobians fast.')
    group.add_argument('--batch_size', type=int, default=4,
                       help='Batch size. (default: 4)')
    return parser.parse_args()


def main():
    """Main function."""
    args = parse_args()
    # Parse model configuration.
    assert (args.stylegan2 and not args.stylegan3) or \
           (not args.stylegan2 and args.stylegan3)
    job_disc = ''
    if args.stylegan2:
        config = dict(model_type='StyleGAN2Generator',
                      resolution=args.img_size,
                      w_dim=args.w_dim,
                      fmaps_base=int(args.scale_stylegan2 * (32 << 10)),
                      fmaps_max=512)
        job_disc += 'stylegan2'
    else:
        if args.stylegan3 and args.cfg == 'R':
            config = dict(model_type='StyleGAN3Generator',
                          resolution=args.img_size,
                          w_dim=args.w_dim,
                          fmaps_base=int(args.scale_stylegan3r * (32 << 10)),
                          fmaps_max=1024,
                          use_radial_filter=True)
            job_disc += 'stylegan3r'
        elif args.stylegan3 and args.cfg == 'T':
            config = dict(model_type='StyleGAN3Generator',
                          resolution=args.img_size,
                          w_dim=args.w_dim,
                          fmaps_base=int(args.scale_stylegan3t * (32 << 10)),
                          fmaps_max=512,
                          use_radial_filter=False,
                          kernel_size=3)
            job_disc += 'stylegan3t'
        else:
            raise TypeError(f'StyleGAN3 config type error: need `R/T`, '
                            f'but got {args.cfg}.')
    job_name = f'seed_{args.seed}_num_{args.nums}_{job_disc}'
    temp_dir = f'work_dirs/{args.job}/{args.data_name}/{job_name}'
    save_dir = args.save_dir or temp_dir
    os.makedirs(save_dir, exist_ok=True)
    if args.save_jpg:
        os.makedirs(f'{save_dir}/images', exist_ok=True)

    print('Building generator...')
    generator = build_model(**config)
    checkpoint_path = args.weight_path
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    generator = generator.eval().cuda()
    print('Finish loading checkpoint.')

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if os.path.exists(args.latent_path):
        latent_zs = np.load(args.latent_path)
        latent_zs = latent_zs[:args.nums]
    else:
        latent_zs = np.random.randn(args.nums, generator.z_dim)
    latent_zs = torch.from_numpy(latent_zs.astype(np.float32))
    latent_zs = latent_zs.cuda()
    with torch.no_grad():
        latent_ws = generator.mapping(latent_zs)['w']
    print(f'Shape of the latent w: {latent_ws.shape}')

    def syn2jaco(w):
        """Wraps the synthesis function so the Jacobian is easy to compute.

        Basically, this function defines a generator that takes an input
        from the W space and synthesizes an image. If the image is larger
        than 256, it is resized to 256 to save time and storage.

        Args:
            w: latent code from the W space.

        Returns:
            An image of size [1, 256, 256].
        """
        wp = w.unsqueeze(1).repeat((1, generator.num_layers, 1))
        image = generator.synthesis(wp)['image']
        if image.shape[-1] > 256:
            scale = 256 / image.shape[-1]
            image = F.interpolate(image, scale_factor=scale)
        image = torch.sum(image, dim=1)
        return image

    jacobians = []
    for idx in tqdm(range(latent_zs.shape[0])):
        latent_w = latent_ws[idx:idx + 1]
        jac_i = jacobian(func=syn2jaco,
                         inputs=latent_w,
                         create_graph=False,
                         strict=False)
        jacobians.append(jac_i)
        if args.save_jpg:
            wp = latent_w.unsqueeze(1).repeat((1, generator.num_layers, 1))
            syn_outputs = generator.synthesis(wp)['image']
            syn_outputs = to_numpy(syn_outputs)
            images = postprocess_image(syn_outputs)
            save_path = f'{save_dir}/images/{idx:06d}.jpg'
            save_image(save_path, images[0])
    jacobians = torch.cat(jacobians, dim=0)
    jacobians = to_numpy(jacobians)
    print(f'Shape of the Jacobian: {jacobians.shape}')
    latent_ws = to_numpy(latent_ws)
    np.save(f'{save_dir}/latent_codes.npy', latent_ws)
    np.save(f'{save_dir}/jacobians_w.npy', jacobians)
    print(f'Finish computing {args.nums} Jacobians.')


if __name__ == '__main__':
    main()
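The heavy lifting above is done by `torch.autograd.functional.jacobian`, which differentiates the full synthesis output with respect to the latent code, one sample at a time. A minimal sketch of the same pattern on a toy function (the function and shapes here are illustrative only):

import torch
from torch.autograd.functional import jacobian

def toy_synthesis(w):
    # Stand-in for generator.synthesis: maps a [1, 4] latent to a [1, 3] "image".
    weight = torch.arange(12, dtype=torch.float32).reshape(4, 3)
    return w @ weight

w = torch.randn(1, 4)
jac = jacobian(func=toy_synthesis, inputs=w)
# The result has one entry per (output element, input element) pair:
print(jac.shape)  # torch.Size([1, 3, 1, 4])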
coordinate.py
ADDED
@@ -0,0 +1,142 @@
# python3.7
"""Utility functions to help define the region coordinates within an image."""

import os
from glob import glob
import argparse
import numpy as np
import cv2
from tqdm import tqdm
from utils.parsing_utils import parse_index


def get_mask_by_coordinates(image_size, coordinate):
    """Gets the mask using the provided coordinates."""
    mask = np.zeros([image_size, image_size], dtype=np.float32)
    center_x, center_y = coordinate[0], coordinate[1]
    crop_x, crop_y = coordinate[2], coordinate[3]
    xx = center_x - crop_x // 2
    yy = center_y - crop_y // 2
    mask[xx:xx + crop_x, yy:yy + crop_y] = 1.
    return mask


def get_mask_by_segmentation(seg_mask, label):
    """Gets the mask using the segmentation array and a label."""
    zeros = np.zeros_like(seg_mask)
    ones = np.ones_like(seg_mask)
    mask = np.where(seg_mask == label, ones, zeros)
    return mask


def get_mask(image_size, coordinate=None, seg_mask=None, labels='1'):
    """Gets the mask using either the coordinates or the segmentation array."""
    if coordinate is not None:
        print('Using coordinates to get the mask!')
        mask = get_mask_by_coordinates(image_size, coordinate)
    else:
        print('Using segmentation to get the mask!')
        print(f'Using labels {labels}')
        mask = np.zeros_like(seg_mask)
        for label_ in labels:
            mask += get_mask_by_segmentation(seg_mask, int(label_))
        mask = np.clip(mask, a_min=0, a_max=1)

    return mask


# For FFHQ: [center_x, center_y, height, width].
# These coordinates are suitable for both FFHQ and MetFaces.
COORDINATE_ffhq = {'left_eye': [120, 95, 20, 38],
                   'right_eye': [120, 159, 20, 38],
                   'eyes': [120, 128, 20, 115],
                   'nose': [142, 131, 40, 46],
                   'mouth': [184, 127, 30, 70],
                   'chin': [217, 130, 42, 110],
                   'eyebrow': [126, 105, 15, 118],
                   }


# For FFHQ unaligned.
COORDINATE_ffhqu = {'eyesr2': [134, 116, 30, 115],
                    'eyesr3': [64, 128, 26, 115],
                    'eyest0': [70, 88, 30, 115],
                    'eyest3': [108, 142, 26, 115],
                    }

# [center_x, center_y, height, width]
COORDINATE_biggan = {'center0': [120, 120, 80, 80],
                     'center1': [120, 120, 130, 130],
                     'center2': [120, 120, 200, 200],
                     'left_side': [128, 64, 256, 128],
                     'top_side': [64, 128, 128, 256],
                     'head0': [89, 115, 49, 70],
                     'head1': [93, 110, 48, 70]}


COORDINATES = {'ffhq': COORDINATE_ffhq,
               'ffhqu': COORDINATE_ffhqu,
               'biggan': COORDINATE_biggan
               }


def parse_args():
    """Parses arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, default='',
                        help='The path to the images.')
    parser.add_argument('--mask_path', type=str, default='',
                        help='The path to the masks.')
    parser.add_argument('--save_dir', type=str, default='',
                        help='The directory to save the results.')
    parser.add_argument('--label', type=str, default=None,
                        help='The label number in the mask.')
    parser.add_argument('--data', type=str, default='ffhq',
                        help='The name of the dataset to test.')
    parser.add_argument('--num', type=int, default=0,
                        help='Index of the first image to display.')
    parser.add_argument('--img_type', type=str, default='jpeg',
                        help='Format of the images.')

    return parser.parse_args()


def main():
    """Main function to show an image with masks."""
    args = parse_args()
    save_dir = args.save_dir or './temp_mask'
    os.makedirs(save_dir, exist_ok=True)
    images = sorted(glob(f'{args.image_path}/*.{args.img_type}'))[args.num:]
    label_files = sorted(glob(f'{args.mask_path}/*.npy'))[args.num:]
    COORDINATE = COORDINATES[args.data]
    for i, image in tqdm(enumerate(images)):
        img = cv2.imread(image)
        im_name = image.split('/')[-1].split('.')[0]
        if args.label is None:
            for name, coord in COORDINATE.items():
                if len(coord) == 0:
                    continue
                mask = np.zeros(img.shape, dtype=np.float32)
                center_x, center_y = coord[0], coord[1]
                crop_x, crop_y = coord[2], coord[3]
                xx = center_x - crop_x // 2
                yy = center_y - crop_y // 2
                mask[xx:xx + crop_x, yy:yy + crop_y, :] = 1.
                img_ = img * mask
                cv2.imwrite(f'{save_dir}/{im_name}_{name}.png', img_)
        else:
            print('Using segmentation to get the mask!')
            seg_mask = np.load(label_files[i])
            labels = parse_index(args.label)
            print(f'Using labels {labels}')
            mask = np.zeros_like(seg_mask)
            for label_ in labels:
                mask += get_mask_by_segmentation(seg_mask, int(label_))
            mask = np.clip(mask, a_min=0, a_max=1)
            img_ = img * mask[:, :, np.newaxis]
            cv2.imwrite(f'{save_dir}/{im_name}_{args.label}.png', img_)


if __name__ == '__main__':
    main()
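For orientation, a quick sketch of how `get_mask` combines with the coordinate tables above; the 256x256 grid size is an assumption that matches the magnitude of the FFHQ coordinates, not a value stated in this diff:

import numpy as np
from coordinate import COORDINATES, get_mask

coords = COORDINATES['ffhq']['eyes']   # [center_x, center_y, height, width]
mask = get_mask(256, coordinate=coords)
print(mask.shape)                      # (256, 256)
# The rectangular mask covers exactly height * width pixels.
assert mask.sum() == coords[2] * coords[3]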
directions/.DS_Store
ADDED
Binary file (6.15 kB)
directions/afhq/.DS_Store
ADDED
Binary file (6.15 kB)
directions/afhq/stylegan3/eyes-r.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:959982185cd0401b8ab984ad2c1c22c01a494d8928a765e353c4fc34e9b079a2
size 2176
directions/ffhq/stylegan2/eyebrows.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f8a7385ee951fa93b34044a768254a937327c87b0d86a569b62ae4593ae0b765
size 2176
directions/ffhq/stylegan2/eyesize.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74d417489944e2f2c6ee9622c4a3745edf20edad5837f5d4b2d89847003d0ec5
size 2176
directions/ffhq/stylegan2/gaze_direction.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:98c7787363716096706db4f6cc039cf5ffc773142a0f790ed185300ddaf8fc0e
size 2176
directions/ffhq/stylegan2/lipstick.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9ce301bb70f35e10c06b9d6fbb485644a65ebf1ac57a080e70fc3448b3343dc6
size 2176
directions/ffhq/stylegan2/mouth.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5cf4fe4565fd2d3f363d0b2e7de1073f1379c8488113b99f8d3bad9ba53a0fa4
size 2176
directions/ffhq/stylegan2/nose_length.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a7d3685b38f18e87a49cc3a79106f18dbdf30b1e8d3db5df21c84d592566ea5
size 2176
directions/ffhq/stylegan3/eyes-r.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91f7c4432a030568b1aac414c30e7f00cf4d441228a625eb8d971f5f28b036db
size 2176
manipulate.py
ADDED
@@ -0,0 +1,253 @@
# python3.7
"""Manipulates synthesized or real images with an existing boundary.

Supports StyleGAN2 and StyleGAN3.
"""

import os.path
import argparse
import numpy as np
from tqdm import tqdm
import torch

from models import build_model
from utils.visualizers.html_visualizer import HtmlVisualizer
from utils.image_utils import save_image
from utils.parsing_utils import parse_index
from utils.image_utils import postprocess_image
from utils.custom_utils import to_numpy, linear_interpolate
from utils.custom_utils import make_transform


def parse_args():
    """Parses arguments."""
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('General options.')
    group.add_argument('weight_path', type=str,
                       help='Path to the pre-trained model weights.')
    group.add_argument('boundary_path', type=str,
                       help='Path to the attribute vectors.')
    group.add_argument('--save_dir', type=str, default=None,
                       help='Directory to save the results. If not specified, '
                            'the results will be saved to '
                            '`work_dirs/{TASK_SPECIFIC}/` by default.')
    group.add_argument('--job', type=str, default='manipulations',
                       help='Name for the job. (default: manipulations)')
    group.add_argument('--seed', type=int, default=4,
                       help='Seed for sampling. (default: 4)')
    group.add_argument('--nums', type=int, default=10,
                       help='Number of samples to synthesize. (default: 10)')
    group.add_argument('--img_size', type=int, default=1024,
                       help='Size of the synthesized images. (default: 1024)')
    group.add_argument('--vis_size', type=int, default=256,
                       help='Size of the visualized images. (default: 256)')
    group.add_argument('--w_dim', type=int, default=512,
                       help='Dimension of the latent w. (default: 512)')
    group.add_argument('--batch_size', type=int, default=4,
                       help='Batch size. (default: 4)')
    group.add_argument('--save_jpg', action='store_true', default=False,
                       help='Whether to save the raw images. '
                            '(default: False)')
    group.add_argument('-d', '--data_name', type=str, default='ffhq',
                       help='Name of the dataset. (default: ffhq)')
    group.add_argument('--latent_path', type=str, default='',
                       help='Path to the given latent codes. (default: None)')
    group.add_argument('--trunc_psi', type=float, default=0.7,
                       help='Psi factor used for truncation. (default: 0.7)')
    group.add_argument('--trunc_layers', type=int, default=8,
                       help='Number of layers to perform truncation. '
                            '(default: 8)')
    group.add_argument('--name', type=str, default='resefa',
                       help='Name used when saving the results.')

    group = parser.add_argument_group('StyleGAN2')
    group.add_argument('--stylegan2', action='store_true',
                       help='Whether to use StyleGAN2. (default: False)')
    group.add_argument('--scale_stylegan2', type=float, default=1.0,
                       help='Scale for the number of channels for StyleGAN2.')
    group.add_argument('--randomize_noise', type=str, default='const',
                       help='Noise type when editing. (const or random)')

    group = parser.add_argument_group('StyleGAN3')
    group.add_argument('--stylegan3', action='store_true',
                       help='Whether to use StyleGAN3. (default: False)')
    group.add_argument('--cfg', type=str, default='T',
                       help='Config of StyleGAN3 (T/R).')
    group.add_argument('--scale_stylegan3r', type=float, default=2.0,
                       help='Scale for the number of channels for StyleGAN3-R.')
    group.add_argument('--scale_stylegan3t', type=float, default=1.0,
                       help='Scale for the number of channels for StyleGAN3-T.')
    group.add_argument('--tx', type=float, default=0,
                       help='Translation along the X-coordinate. '
                            '(default: 0.0)')
    group.add_argument('--ty', type=float, default=0,
                       help='Translation along the Y-coordinate. '
                            '(default: 0.0)')
    group.add_argument('--rotate', type=float, default=0,
                       help='Rotation angle in degrees. (default: 0)')

    group = parser.add_argument_group('Manipulation')
    group.add_argument('--mani_layers', type=str, default='4,5,6,7',
                       help='The layers to be manipulated '
                            '(default: 4,5,6,7). For the eyebrow and lipstick '
                            'attributes, use layers [8-11] instead.')
    group.add_argument('--step', type=int, default=7,
                       help='Number of manipulation steps. (default: 7)')
    group.add_argument('--start', type=int, default=0,
                       help='The start index of the manipulation directions.')
    group.add_argument('--end', type=int, default=1,
                       help='The end index of the manipulation directions.')
    group.add_argument('--start_distance', type=float, default=-10.0,
                       help='Start distance for manipulation. '
                            '(default: -10.0)')
    group.add_argument('--end_distance', type=float, default=10.0,
                       help='End distance for manipulation. (default: 10.0)')

    return parser.parse_args()


def main():
    """Main function."""
    args = parse_args()
    # Parse model configuration.
    assert (args.stylegan2 and not args.stylegan3) or \
           (not args.stylegan2 and args.stylegan3)
    checkpoint_path = args.weight_path
    boundary_path = args.boundary_path
    assert os.path.exists(checkpoint_path)
    assert os.path.exists(boundary_path)
    boundary_name = os.path.splitext(os.path.basename(boundary_path))[0]
    job_disc = ''
    if args.stylegan2:
        config = dict(model_type='StyleGAN2Generator',
                      resolution=args.img_size,
                      w_dim=args.w_dim,
                      fmaps_base=int(args.scale_stylegan2 * (32 << 10)),
                      fmaps_max=512)
        job_disc += 'stylegan2'
    else:
        if args.stylegan3 and args.cfg == 'R':
            config = dict(model_type='StyleGAN3Generator',
                          resolution=args.img_size,
                          w_dim=args.w_dim,
                          fmaps_base=int(args.scale_stylegan3r * (32 << 10)),
                          fmaps_max=1024,
                          use_radial_filter=True)
            job_disc += 'stylegan3r'
        elif args.stylegan3 and args.cfg == 'T':
            config = dict(model_type='StyleGAN3Generator',
                          resolution=args.img_size,
                          w_dim=args.w_dim,
                          fmaps_base=int(args.scale_stylegan3t * (32 << 10)),
                          fmaps_max=512,
                          use_radial_filter=False,
                          kernel_size=3)
            job_disc += 'stylegan3t'
        else:
            raise TypeError(f'StyleGAN3 config type error: need `R/T`, '
                            f'but got {args.cfg} instead.')

    # Get work directory and job name.
    save_dir = args.save_dir or f'work_dirs/{args.job}/{args.data_name}'
    os.makedirs(save_dir, exist_ok=True)
    job_name = f'seed_{args.seed}_num_{args.nums}_{job_disc}_{boundary_name}'
    os.makedirs(f'{save_dir}/{job_name}', exist_ok=True)

    print('Building generator...')
    generator = build_model(**config)
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    generator = generator.eval().cuda()
    print('Finish loading checkpoint.')
    if args.stylegan3 and hasattr(generator.synthesis, 'early_layer'):
        m = make_transform(args.tx, args.ty, args.rotate)
        m = np.linalg.inv(m)
        generator.synthesis.early_layer.transform.copy_(torch.from_numpy(m))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if os.path.exists(args.latent_path):
        print(f'Loading latent codes from {args.latent_path}')
        latent_zs = np.load(args.latent_path)
        latent_zs = latent_zs[:args.nums]
    else:
        print('Sampling latent codes randomly')
        latent_zs = np.random.randn(args.nums, generator.z_dim)
    latent_zs = torch.from_numpy(latent_zs.astype(np.float32))
    latent_zs = latent_zs.cuda()
    num_images = latent_zs.shape[0]
    wp = []
    for idx in range(0, num_images, args.batch_size):
        latent_z = latent_zs[idx:idx + args.batch_size]
        latent_w_ = generator.mapping(latent_z, None)['wp']
        wp.append(latent_w_)
    wp = torch.cat(wp, dim=0)
    trunc_psi = args.trunc_psi
    trunc_layers = args.trunc_layers
    if trunc_psi < 1.0 and trunc_layers > 0:
        w_avg = generator.w_avg
        w_avg = w_avg.reshape(1, -1, generator.w_dim)[:, :trunc_layers]
        wp[:, :trunc_layers] = w_avg.lerp(wp[:, :trunc_layers], trunc_psi)
    print(f'Shape of the latent ws: {wp.shape}')
    image_list = []
    for i in range(num_images):
        image_list.append(f'{i:06d}')

    print('Loading boundary.')
    directions = np.load(boundary_path)
    layer_index = parse_index(args.mani_layers)
    if not layer_index:
        layer_index = list(range(generator.num_layers - 1))
    print(f'Manipulating on layers `{layer_index}`.')

    vis_size = None if args.vis_size == 0 else args.vis_size
    delta_num = args.end - args.start
    visualizer = HtmlVisualizer(num_rows=num_images * delta_num,
                                num_cols=args.step + 2,
                                image_size=vis_size)
    visualizer.set_headers(
        ['Name', 'Origin'] +
        [f'Step {i:02d}' for i in range(1, args.step + 1)]
    )
    # Manipulate images.
    print('Start manipulation.')
    for row in tqdm(range(num_images)):
        latent_w = wp[row:row + 1]
        images_ori = generator.synthesis(latent_w)['image']
        images_ori = postprocess_image(to_numpy(images_ori))
        if args.save_jpg:
            save_image(f'{save_dir}/{job_name}/{row:06d}_orin.jpg',
                       images_ori[0])
        for num_direc in range(args.start, args.end):
            html_row = num_direc - args.start
            direction = directions[num_direc:num_direc + 1]
            direction = np.tile(direction, [1, generator.num_layers, 1])
            visualizer.set_cell(row * delta_num + html_row, 0,
                                text=f'{image_list[row]}_{num_direc:03d}')
            visualizer.set_cell(row * delta_num + html_row, 1,
                                image=images_ori[0])
            mani_codes = linear_interpolate(latent_code=to_numpy(latent_w),
                                            boundary=direction,
                                            layer_index=layer_index,
                                            start_distance=args.start_distance,
                                            end_distance=args.end_distance,
                                            steps=args.step)
            mani_codes = torch.from_numpy(mani_codes.astype(np.float32)).cuda()
            for idx in range(0, mani_codes.shape[0], args.batch_size):
                codes_ = mani_codes[idx:idx + args.batch_size]
                images_ = generator.synthesis(codes_)['image']
                images_ = postprocess_image(to_numpy(images_))
                for i in range(images_.shape[0]):
                    visualizer.set_cell(row * delta_num + html_row,
                                        idx + i + 2,
                                        image=images_[i])
                    if args.save_jpg:
                        save_image(f'{save_dir}/{job_name}/{row:06d}_ind_'
                                   f'{num_direc:06d}_mani_{idx + i:06d}.jpg',
                                   images_[i])
    # Save results.
    np.save(f'{save_dir}/{job_name}/latent_codes.npy', to_numpy(wp))
    visualizer.save(f'{save_dir}/{job_name}_{args.name}.html')


if __name__ == '__main__':
    main()
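The editing step above relies on `linear_interpolate` from `utils.custom_utils`, which is not included in this diff. The following is a plausible minimal sketch of what such a layer-wise interpolation does, written from the call site's argument names; it is an assumption, not the repository's implementation:

import numpy as np

def linear_interpolate_sketch(latent_code, boundary, layer_index,
                              start_distance=-10.0, end_distance=10.0,
                              steps=7):
    """Moves `latent_code` along `boundary` on the selected layers only.

    latent_code: [1, num_layers, w_dim]; boundary: [1, num_layers, w_dim].
    Returns an array of shape [steps, num_layers, w_dim].
    """
    distances = np.linspace(start_distance, end_distance, steps)
    codes = np.tile(latent_code, [steps, 1, 1])
    for step, d in enumerate(distances):
        # Only the chosen layers are shifted; the rest stay untouched.
        codes[step, layer_index] += d * boundary[0, layer_index]
    return codes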
models/__init__.py
ADDED
@@ -0,0 +1,45 @@
# python3.7
"""Collects all models."""

from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator
from .stylegan3_generator import StyleGAN3Generator
from .ghfeat_encoder import GHFeatEncoder
from .perceptual_model import PerceptualModel
from .inception_model import InceptionModel

__all__ = ['build_model']

_MODELS = {
    'PGGANGenerator': PGGANGenerator,
    'PGGANDiscriminator': PGGANDiscriminator,
    'StyleGANGenerator': StyleGANGenerator,
    'StyleGANDiscriminator': StyleGANDiscriminator,
    'StyleGAN2Generator': StyleGAN2Generator,
    'StyleGAN2Discriminator': StyleGAN2Discriminator,
    'StyleGAN3Generator': StyleGAN3Generator,
    'GHFeatEncoder': GHFeatEncoder,
    'PerceptualModel': PerceptualModel.build_model,
    'InceptionModel': InceptionModel.build_model
}


def build_model(model_type, **kwargs):
    """Builds a model based on its class type.

    Args:
        model_type: Class type to which the model belongs, which is case
            sensitive.
        **kwargs: Additional arguments to build the model.

    Raises:
        ValueError: If the `model_type` is not supported.
    """
    if model_type not in _MODELS:
        raise ValueError(f'Invalid model type: `{model_type}`!\n'
                         f'Types allowed: {list(_MODELS)}.')
    return _MODELS[model_type](**kwargs)
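For reference, `build_model` is the entry point used by the scripts above; a minimal call mirroring the StyleGAN2 config assembled in `compute_jacobian.py` and `manipulate.py` (the resolution here is just an example value):

from models import build_model

generator = build_model(model_type='StyleGAN2Generator',
                        resolution=256,
                        w_dim=512,
                        fmaps_base=32 << 10,
                        fmaps_max=512)
print(type(generator).__name__)  # StyleGAN2Generator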
models/ghfeat_encoder.py
ADDED
@@ -0,0 +1,563 @@
# python3.7
"""Contains the implementation of the encoder used in GH-Feat (and IDInvert).

ResNet is used as the backbone.

GH-Feat paper: https://arxiv.org/pdf/2007.10379.pdf
IDInvert paper: https://arxiv.org/pdf/2004.00049.pdf

NOTE: Please use `latent_dim` and `num_latents_per_head` to control the
inversion space, such as the Y space used in GH-Feat and the W space used in
IDInvert. In addition, IDInvert sets `use_fpn` and `use_sam` to `False` by
default.
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

__all__ = ['GHFeatEncoder']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# pylint: disable=missing-function-docstring

class BasicBlock(nn.Module):
    """Implementation of the ResNet BasicBlock."""

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 base_width=64,
                 stride=1,
                 groups=1,
                 dilation=1,
                 norm_layer=None,
                 downsample=None):
        super().__init__()
        if base_width != 64:
            raise ValueError(f'BasicBlock of ResNet only supports '
                             f'`base_width=64`, but {base_width} received!')
        if stride not in [1, 2]:
            raise ValueError(f'BasicBlock of ResNet only supports `stride=1` '
                             f'and `stride=2`, but {stride} received!')
        if groups != 1:
            raise ValueError(f'BasicBlock of ResNet only supports `groups=1`, '
                             f'but {groups} received!')
        if dilation != 1:
            raise ValueError(f'BasicBlock of ResNet only supports '
                             f'`dilation=1`, but {dilation} received!')
        assert self.expansion == 1

        self.stride = stride
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = nn.Conv2d(in_channels=inplanes,
                               out_channels=planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               groups=1,
                               dilation=1,
                               bias=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=planes,
                               out_channels=planes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               groups=1,
                               dilation=1,
                               bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = self.downsample(x) if self.downsample is not None else x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out + identity)

        return out


class Bottleneck(nn.Module):
    """Implementation of the ResNet Bottleneck."""

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 base_width=64,
                 stride=1,
                 groups=1,
                 dilation=1,
                 norm_layer=None,
                 downsample=None):
        super().__init__()
        if stride not in [1, 2]:
            raise ValueError(f'Bottleneck of ResNet only supports `stride=1` '
                             f'and `stride=2`, but {stride} received!')

        width = int(planes * (base_width / 64)) * groups
        self.stride = stride
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = nn.Conv2d(in_channels=inplanes,
                               out_channels=width,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               dilation=1,
                               groups=1,
                               bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(in_channels=width,
                               out_channels=width,
                               kernel_size=3,
                               stride=stride,
                               padding=dilation,
                               groups=groups,
                               dilation=dilation,
                               bias=False)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(in_channels=width,
                               out_channels=planes * self.expansion,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               dilation=1,
                               groups=1,
                               bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = self.downsample(x) if self.downsample is not None else x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu(out + identity)

        return out


class GHFeatEncoder(nn.Module):
    """Defines the ResNet-based encoder network for GAN inversion.

    On top of the backbone, there are several task heads to produce the
    inverted codes. Please use `latent_dim` and `num_latents_per_head` to
    define the structure. For example, `latent_dim = [512] * 14` and
    `num_latents_per_head = [4, 4, 6]` can be used for StyleGAN inversion
    with 14-layer latent codes, where 3 task heads (corresponding to 4, 4, 6
    layers, respectively) are used.

    Settings for the encoder network:

    (1) resolution: The resolution of the output image.
    (2) latent_dim: Dimension of the latent space. A number (one code will be
        produced), or a list of numbers regarding layer-wise latent codes.
    (3) num_latents_per_head: Number of latents produced by each head.
    (4) image_channels: Number of channels of the output image. (default: 3)
    (5) final_res: Final resolution of the convolutional layers. (default: 4)

    ResNet-related settings:

    (1) network_depth: Depth of the network, like 18 for ResNet18.
        (default: 18)
    (2) inplanes: Number of channels of the first convolutional layer.
        (default: 64)
    (3) groups: Groups of the convolution, used in ResNet. (default: 1)
    (4) width_per_group: Number of channels per group, used in ResNet.
        (default: 64)
    (5) replace_stride_with_dilation: Whether to replace stride with dilation,
        used in ResNet. (default: None)
    (6) norm_layer: Normalization layer used in the encoder. If set as `None`,
        `nn.BatchNorm2d` will be used. Also, please NOTE that when using batch
        normalization, the batch size is required to be larger than one for
        training. (default: nn.BatchNorm2d)
    (7) max_channels: Maximum number of channels in each layer. (default: 512)

    Task-head related settings:

    (1) use_fpn: Whether to use a Feature Pyramid Network (FPN) before
        outputting the latent code. (default: True)
    (2) fpn_channels: Number of channels used in FPN. (default: 512)
    (3) use_sam: Whether to use a Spatial Alignment Module (SAM) before
        outputting the latent code. (default: True)
    (4) sam_channels: Number of channels used in SAM. (default: 512)
    """

    arch_settings = {
        18: (BasicBlock, [2, 2, 2, 2]),
        34: (BasicBlock, [3, 4, 6, 3]),
        50: (Bottleneck, [3, 4, 6, 3]),
        101: (Bottleneck, [3, 4, 23, 3]),
        152: (Bottleneck, [3, 8, 36, 3])
    }

    def __init__(self,
                 resolution,
                 latent_dim,
                 num_latents_per_head,
                 image_channels=3,
                 final_res=4,
                 network_depth=18,
                 inplanes=64,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=nn.BatchNorm2d,
                 max_channels=512,
                 use_fpn=True,
                 fpn_channels=512,
                 use_sam=True,
                 sam_channels=512):
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        if network_depth not in self.arch_settings:
            raise ValueError(f'Invalid network depth: `{network_depth}`!\n'
                             f'Options allowed: '
                             f'{list(self.arch_settings.keys())}.')
        if isinstance(latent_dim, int):
            latent_dim = [latent_dim]
        assert isinstance(latent_dim, (list, tuple))
        assert isinstance(num_latents_per_head, (list, tuple))
        assert sum(num_latents_per_head) == len(latent_dim)

        self.resolution = resolution
        self.latent_dim = latent_dim
        self.num_latents_per_head = num_latents_per_head
        self.num_heads = len(self.num_latents_per_head)
        self.image_channels = image_channels
        self.final_res = final_res
        self.inplanes = inplanes
        self.network_depth = network_depth
        self.groups = groups
        self.dilation = 1
        self.base_width = width_per_group
        self.replace_stride_with_dilation = replace_stride_with_dilation
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if norm_layer == nn.BatchNorm2d and dist.is_initialized():
            norm_layer = nn.SyncBatchNorm
        self.norm_layer = norm_layer
        self.max_channels = max_channels
        self.use_fpn = use_fpn
        self.fpn_channels = fpn_channels
        self.use_sam = use_sam
        self.sam_channels = sam_channels

        block_fn, num_blocks_per_stage = self.arch_settings[network_depth]

        self.num_stages = int(np.log2(resolution // final_res)) - 1
        # Add one block for additional stages.
        for i in range(len(num_blocks_per_stage), self.num_stages):
            num_blocks_per_stage.append(1)
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False] * self.num_stages

        # Backbone.
        self.conv1 = nn.Conv2d(in_channels=self.image_channels,
                               out_channels=self.inplanes,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.stage_channels = [self.inplanes]
        self.stages = nn.ModuleList()
        for i in range(self.num_stages):
            inplanes = self.inplanes if i == 0 else planes * block_fn.expansion
            planes = min(self.max_channels, self.inplanes * (2 ** i))
            num_blocks = num_blocks_per_stage[i]
            stride = 1 if i == 0 else 2
            dilate = replace_stride_with_dilation[i]
            self.stages.append(self._make_stage(block_fn=block_fn,
                                                inplanes=inplanes,
                                                planes=planes,
                                                num_blocks=num_blocks,
                                                stride=stride,
                                                dilate=dilate))
            self.stage_channels.append(planes * block_fn.expansion)

        if self.num_heads > len(self.stage_channels):
            raise ValueError('Number of task heads is larger than number of '
                             'stages! Please reduce the number of heads.')

        # Task-head.
        if self.num_heads == 1:
            self.use_fpn = False
            self.use_sam = False

        if self.use_fpn:
            fpn_pyramid_channels = self.stage_channels[-self.num_heads:]
            self.fpn = FPN(pyramid_channels=fpn_pyramid_channels,
                           out_channels=self.fpn_channels)
        if self.use_sam:
            if self.use_fpn:
                sam_pyramid_channels = [self.fpn_channels] * self.num_heads
            else:
                sam_pyramid_channels = self.stage_channels[-self.num_heads:]
            self.sam = SAM(pyramid_channels=sam_pyramid_channels,
                           out_channels=self.sam_channels)

        self.heads = nn.ModuleList()
        for head_idx in range(self.num_heads):
            # Parse in_channels.
            if self.use_sam:
                in_channels = self.sam_channels
            elif self.use_fpn:
                in_channels = self.fpn_channels
            else:
                in_channels = self.stage_channels[head_idx - self.num_heads]
            in_channels = in_channels * final_res * final_res

            # Parse out_channels.
            start_latent_idx = sum(self.num_latents_per_head[:head_idx])
            end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
            out_channels = sum(self.latent_dim[start_latent_idx:end_latent_idx])

            self.heads.append(CodeHead(in_channels=in_channels,
                                       out_channels=out_channels,
                                       norm_layer=self.norm_layer))

    def _make_stage(self,
                    block_fn,
                    inplanes,
                    planes,
                    num_blocks,
                    stride,
                    dilate):
        norm_layer = self.norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or inplanes != planes * block_fn.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=inplanes,
                          out_channels=planes * block_fn.expansion,
                          kernel_size=1,
                          stride=stride,
                          padding=0,
                          dilation=1,
                          groups=1,
                          bias=False),
                norm_layer(planes * block_fn.expansion),
            )

        blocks = []
        blocks.append(block_fn(inplanes=inplanes,
                               planes=planes,
                               base_width=self.base_width,
                               stride=stride,
                               groups=self.groups,
                               dilation=previous_dilation,
                               norm_layer=norm_layer,
                               downsample=downsample))
        for _ in range(1, num_blocks):
            blocks.append(block_fn(inplanes=planes * block_fn.expansion,
                                   planes=planes,
                                   base_width=self.base_width,
                                   stride=1,
                                   groups=self.groups,
                                   dilation=self.dilation,
                                   norm_layer=norm_layer,
                                   downsample=None))

        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        features = [x]
        for i in range(self.num_stages):
            x = self.stages[i](x)
            features.append(x)
        features = features[-self.num_heads:]

        if self.use_fpn:
            features = self.fpn(features)
        if self.use_sam:
            features = self.sam(features)
        else:
            final_size = features[-1].shape[2:]
            for i in range(self.num_heads - 1):
                features[i] = F.adaptive_avg_pool2d(features[i], final_size)

        outputs = []
        for head_idx in range(self.num_heads):
            codes = self.heads[head_idx](features[head_idx])
            start_latent_idx = sum(self.num_latents_per_head[:head_idx])
            end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
            split_size = self.latent_dim[start_latent_idx:end_latent_idx]
            outputs.extend(torch.split(codes, split_size, dim=1))
        max_dim = max(self.latent_dim)
        for i, dim in enumerate(self.latent_dim):
            if dim < max_dim:
                outputs[i] = F.pad(outputs[i], (0, max_dim - dim))
            outputs[i] = outputs[i].unsqueeze(1)

        return torch.cat(outputs, dim=1)


class FPN(nn.Module):
    """Implementation of the Feature Pyramid Network (FPN).

    The input of this module is a pyramid of features with decreasing
    resolutions. This module fuses these multi-level features from
    `top_level` to `bottom_level`. In particular, starting from the
    `top_level`, each feature is convolved, upsampled, and fused into its
    previous feature (which is also convolved).

    Args:
        pyramid_channels: A list of integers, each of which indicates the
            number of channels of the feature from a particular level.
        out_channels: Number of channels for each output.

    Returns:
        A list of feature maps, each of which has `out_channels` channels.
    """

    def __init__(self, pyramid_channels, out_channels):
        super().__init__()
        assert isinstance(pyramid_channels, (list, tuple))
        self.num_levels = len(pyramid_channels)

        self.lateral_layers = nn.ModuleList()
        self.feature_layers = nn.ModuleList()
        for i in range(self.num_levels):
            in_channels = pyramid_channels[i]
            self.lateral_layers.append(nn.Conv2d(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=3,
                                                 padding=1,
                                                 bias=True))
            self.feature_layers.append(nn.Conv2d(in_channels=out_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=3,
                                                 padding=1,
                                                 bias=True))

    def forward(self, inputs):
        if len(inputs) != self.num_levels:
            raise ValueError('Number of inputs and `num_levels` mismatch!')

        # Project all related features to `out_channels`.
        laterals = []
        for i in range(self.num_levels):
            laterals.append(self.lateral_layers[i](inputs[i]))

        # Fusion, starting from `top_level`.
        for i in range(self.num_levels - 1, 0, -1):
            scale_factor = laterals[i - 1].shape[2] // laterals[i].shape[2]
            laterals[i - 1] = (laterals[i - 1] +
                               F.interpolate(laterals[i],
                                             mode='nearest',
                                             scale_factor=scale_factor))
|
489 |
+
|
490 |
+
# Get outputs.
|
491 |
+
outputs = []
|
492 |
+
for i, lateral in enumerate(laterals):
|
493 |
+
outputs.append(self.feature_layers[i](lateral))
|
494 |
+
|
495 |
+
return outputs
|
496 |
+
|
497 |
+
|
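
# NOTE (added for clarity): a minimal smoke test of the FPN module above.
# The channel counts and resolutions below are illustrative assumptions for
# the example (a ResNet-50-style pyramid), not values fixed by the encoder.
# The function is defined but never called, so importing this file stays
# side-effect free.
def _fpn_example():
    fpn = FPN(pyramid_channels=[256, 512, 1024, 2048], out_channels=512)
    pyramid = [torch.randn(2, 256, 64, 64),    # bottom level (finest)
               torch.randn(2, 512, 32, 32),
               torch.randn(2, 1024, 16, 16),
               torch.randn(2, 2048, 8, 8)]     # top level (coarsest)
    fused = fpn(pyramid)
    # Every output keeps its resolution but now has 512 channels:
    # [(2, 512, 64, 64), (2, 512, 32, 32), (2, 512, 16, 16), (2, 512, 8, 8)]
    return [tuple(f.shape) for f in fused]
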
class SAM(nn.Module):
    """Implementation of Spatial Alignment Module (SAM).

    The input of this module is a pyramid of features with reducing
    resolutions. Then this module downsamples all levels of feature to the
    minimum resolution and fuses it with the smallest feature map.

    Args:
        pyramid_channels: A list of integers, each of which indicates the
            number of channels of the feature from a particular level.
        out_channels: Number of channels for each output.

    Returns:
        A list of feature maps, each of which has `out_channels` channels.
    """

    def __init__(self, pyramid_channels, out_channels):
        super().__init__()
        assert isinstance(pyramid_channels, (list, tuple))
        self.num_levels = len(pyramid_channels)

        self.fusion_layers = nn.ModuleList()
        for i in range(self.num_levels):
            in_channels = pyramid_channels[i]
            self.fusion_layers.append(nn.Conv2d(in_channels=in_channels,
                                                out_channels=out_channels,
                                                kernel_size=3,
                                                padding=1,
                                                bias=True))

    def forward(self, inputs):
        if len(inputs) != self.num_levels:
            raise ValueError('Number of inputs and `num_levels` mismatch!')

        output_res = inputs[-1].shape[2:]
        for i in range(self.num_levels - 1, -1, -1):
            if i != self.num_levels - 1:
                inputs[i] = F.adaptive_avg_pool2d(inputs[i], output_res)
            inputs[i] = self.fusion_layers[i](inputs[i])
            if i != self.num_levels - 1:
                inputs[i] = inputs[i] + inputs[-1]

        return inputs


class CodeHead(nn.Module):
    """Implementation of the task-head to produce inverted codes."""

    def __init__(self, in_channels, out_channels, norm_layer):
        super().__init__()
        self.fc = nn.Linear(in_channels, out_channels, bias=True)
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_channels)

    def forward(self, x):
        if x.ndim > 2:
            x = x.flatten(start_dim=1)
        latent = self.fc(x)
        latent = latent.unsqueeze(2).unsqueeze(3)
        latent = self.norm(latent)

        return latent.flatten(start_dim=1)

# pylint: enable=missing-function-docstring
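
A note on the encoder's output layout (added for clarity): `forward()` above splits each head's flat code by `latent_dim`, pads shorter codes up to the longest dimension, and stacks everything into one `[N, num_latents, max_dim]` tensor. A toy sketch with illustrative numbers (assumed for the example, not taken from any config in this repo; the `F.pad` branch only triggers when the per-layer dimensions differ):

import torch

latent_dim = [512] * 5            # per-layer latent dimensions (W+ space)
num_latents_per_head = [3, 2]     # head 0 -> layers 0-2, head 1 -> layers 3-4

codes_head0 = torch.randn(4, sum(latent_dim[:3]))   # [N, 1536] from head 0
codes_head1 = torch.randn(4, sum(latent_dim[3:]))   # [N, 1024] from head 1

outputs = list(torch.split(codes_head0, latent_dim[:3], dim=1))
outputs += list(torch.split(codes_head1, latent_dim[3:], dim=1))
w_plus = torch.cat([o.unsqueeze(1) for o in outputs], dim=1)
print(w_plus.shape)  # torch.Size([4, 5, 512]) -- one code per synthesis layer
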
models/inception_model.py
ADDED
@@ -0,0 +1,562 @@
# python3.7
"""Contains the Inception V3 model, which is used for inference ONLY.

This file is mostly borrowed from `torchvision/models/inception.py`.

The Inception model is widely used to compute the FID or IS metric for
evaluating generative models. However, the pre-trained model from
`torchvision` is slightly different from the TensorFlow version

http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz

which is used by the official FID implementation

https://github.com/bioinf-jku/TTUR

In particular:

(1) The number of classes in the TensorFlow model is 1008 instead of 1000.
(2) The avg_pool() layers in the TensorFlow model do not include the padded
    zeros.
(3) The last Inception E Block in the TensorFlow model uses max_pool() instead
    of avg_pool().

Hence, to align the evaluation results with those from the TensorFlow
implementation, we modified the Inception model to support both versions.
Please use the `align_tf` argument to control the version.
"""

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from utils.misc import download_url

__all__ = ['InceptionModel']

# pylint: disable=line-too-long

_MODEL_URL_SHA256 = {
    # This model is provided by `torchvision`, which is ported from TensorFlow.
    'torchvision_official': (
        'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
        '1a9a5a14f40645a370184bd54f4e8e631351e71399112b43ad0294a79da290c8'  # hash sha256
    ),

    # This model is provided by https://github.com/mseitzer/pytorch-fid
    'tf_inception_v3': (
        'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth',
        '6726825d0af5f729cebd5821db510b11b1cfad8faad88a03f1befd49fb9129b2'  # hash sha256
    )
}


class InceptionModel(object):
    """Defines the Inception (V3) model.

    This is a static class, which is used to avoid building this model
    repeatedly. Consequently, this model is particularly used for inference,
    like computing FID. If training is required, please use the model from
    `torchvision.models` or implement it by yourself.

    NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
    order and pixel range [-1, 1], and will also resize the images to shape
    [299, 299] automatically. If your input is normalized by subtracting
    (0.485, 0.456, 0.406) and dividing (0.229, 0.224, 0.225), please use
    `transform_input` in the `forward()` function to un-normalize it.
    """
    models = dict()

    @staticmethod
    def build_model(align_tf=True):
        """Builds the model and loads pre-trained weights.

        If `align_tf` is set as True, the model will predict 1008 classes, and
        the pre-trained weight from `https://github.com/mseitzer/pytorch-fid`
        will be loaded. Otherwise, the model will predict 1000 classes, and
        will load the model from `torchvision`.

        The built model supports the following arguments when forwarding:

        - transform_input: Whether to transform the input back to pixel range
            (-1, 1). Please disable this argument if your input is already
            with pixel range (-1, 1). (default: False)
        - output_logits: Whether to output the categorical logits instead of
            features. (default: False)
        - remove_logits_bias: Whether to remove the bias when computing the
            logits. The official implementation removes the bias by default.
            Please refer to
            `https://github.com/openai/improved-gan/blob/master/inception_score/model.py`.
            (default: False)
        - output_predictions: Whether to output the final predictions, i.e.,
            `softmax(logits)`. (default: False)
        """
        if align_tf:
            num_classes = 1008
            model_source = 'tf_inception_v3'
        else:
            num_classes = 1000
            model_source = 'torchvision_official'

        fingerprint = model_source

        if fingerprint not in InceptionModel.models:
            # Build model.
            model = Inception3(num_classes=num_classes,
                               aux_logits=False,
                               init_weights=False,
                               align_tf=align_tf)

            # Download pre-trained weights.
            if dist.is_initialized() and dist.get_rank() != 0:
                dist.barrier()  # Download by chief.

            url, sha256 = _MODEL_URL_SHA256[model_source]
            filename = f'inception_model_{model_source}_{sha256}.pth'
            model_path, hash_check = download_url(url,
                                                  filename=filename,
                                                  sha256=sha256)
            state_dict = torch.load(model_path, map_location='cpu')
            if hash_check is False:
                warnings.warn(f'Hash check failed! The remote file from URL '
                              f'`{url}` may be changed, or the downloading is '
                              f'interrupted. The loaded inception model may '
                              f'have unexpected behavior.')

            if dist.is_initialized() and dist.get_rank() == 0:
                dist.barrier()  # Wait for other replicas.

            # Load weights.
            model.load_state_dict(state_dict, strict=False)
            del state_dict

            # For inference only.
            model.eval().requires_grad_(False).cuda()
            InceptionModel.models[fingerprint] = model

        return InceptionModel.models[fingerprint]

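
# NOTE (added for clarity): a minimal usage sketch of `build_model()` for
# metric computation. It assumes a CUDA device, since the built model is
# moved to GPU; with `align_tf=True` the logits/predictions are
# 1008-dimensional. Defined but never called, so importing stays
# side-effect free.
def _inception_example():
    inception = InceptionModel.build_model(align_tf=True)
    fake = torch.randn(8, 3, 256, 256).cuda()   # RGB, pixel range [-1, 1]
    with torch.no_grad():
        feats = inception(fake)                            # [8, 2048] for FID
        probs = inception(fake, output_predictions=True)   # [8, 1008] for IS
    return feats, probs
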
# pylint: disable=missing-function-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=super-with-arguments
# pylint: disable=consider-merging-isinstance
# pylint: disable=import-outside-toplevel
# pylint: disable=no-else-return

class Inception3(nn.Module):

    def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None,
                 init_weights=True, align_tf=True):
        super(Inception3, self).__init__()
        if inception_blocks is None:
            inception_blocks = [
                BasicConv2d, InceptionA, InceptionB, InceptionC,
                InceptionD, InceptionE, InceptionAux
            ]
        assert len(inception_blocks) == 7
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]

        self.aux_logits = aux_logits
        self.align_tf = align_tf
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf)
        self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf)
        self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf)
        self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
        self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
        self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf)
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280, align_tf=self.align_tf)
        self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf)
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    import scipy.stats as stats
                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                    X = stats.truncnorm(-2, 2, scale=stddev)
                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                    values = values.view(m.weight.size())
                    with torch.no_grad():
                        m.weight.copy_(values)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def _transform_input(x, transform_input=False):
        if transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self,
                 x,
                 output_logits=False,
                 remove_logits_bias=False,
                 output_predictions=False):
        # Upsample if necessary.
        if x.shape[2] != 299 or x.shape[3] != 299:
            if self.align_tf:
                theta = torch.eye(2, 3).to(x)
                theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 299
                theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 299
                theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
                grid = F.affine_grid(theta,
                                     size=(x.shape[0], x.shape[1], 299, 299),
                                     align_corners=False)
                x = F.grid_sample(x, grid,
                                  mode='bilinear',
                                  padding_mode='border',
                                  align_corners=False)
            else:
                x = F.interpolate(
                    x, size=(299, 299), mode='bilinear', align_corners=False)
        if x.shape[1] == 1:
            x = x.repeat((1, 3, 1, 1))

        if self.align_tf:
            x = (x * 127.5 + 127.5 - 128) / 128

        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        if self.training and self.aux_logits:
            aux = self.AuxLogits(x)
        else:
            aux = None
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        x = F.dropout(x, training=self.training)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        if output_logits or output_predictions:
            x = self.fc(x)
            # N x 1000 (num_classes)
            if remove_logits_bias:
                x = x - self.fc.bias.view(1, -1)
            if output_predictions:
                x = F.softmax(x, dim=1)
        return x, aux

    def forward(self,
                x,
                transform_input=False,
                output_logits=False,
                remove_logits_bias=False,
                output_predictions=False):
        x = self._transform_input(x, transform_input)
        x, aux = self._forward(
            x, output_logits, remove_logits_bias, output_predictions)
        if self.training and self.aux_logits:
            return x, aux
        else:
            return x


class InceptionA(nn.Module):

    def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False):
        super(InceptionA, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
        self.pool_include_padding = not align_tf

    def _forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=self.pool_include_padding)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x):
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):

    def __init__(self, in_channels, conv_block=None):
        super(InceptionB, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x):
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):

    def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False):
        super(InceptionC, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
        self.pool_include_padding = not align_tf

    def _forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=self.pool_include_padding)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x):
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):

    def __init__(self, in_channels, conv_block=None):
        super(InceptionD, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x):
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):

    def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False):
        super(InceptionE, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
        self.pool_include_padding = not align_tf
        self.use_max_pool = use_max_pool

    def _forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        if self.use_max_pool:
            branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        else:
            branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                       count_include_pad=self.pool_include_padding)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x):
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):

    def __init__(self, in_channels, num_classes, conv_block=None):
        super(InceptionAux, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x 1000
        return x


class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring
# pylint: enable=missing-class-docstring
# pylint: enable=super-with-arguments
# pylint: enable=consider-merging-isinstance
# pylint: enable=import-outside-toplevel
# pylint: enable=no-else-return
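
A side note on the `align_tf` resize branch in `Inception3._forward()` above (added for clarity): `F.interpolate` and the `affine_grid`/`grid_sample` pair place their bilinear sample points differently, and FID is sensitive to such sub-pixel shifts, which is why the TF-aligned path avoids `F.interpolate`. A standalone sketch showing that the two paths disagree (the input size is an arbitrary assumption for the demo):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 128, 128)
a = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

theta = torch.eye(2, 3)
theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 299
theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 299
grid = F.affine_grid(theta.unsqueeze(0), size=(1, 3, 299, 299),
                     align_corners=False)
b = F.grid_sample(x, grid, mode='bilinear', padding_mode='border',
                  align_corners=False)
print((a - b).abs().max())  # non-zero: the two resizes do not coincide
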
models/perceptual_model.py
ADDED
@@ -0,0 +1,519 @@
# python3.7
"""Contains the VGG16 model, which is used for inference ONLY.

VGG16 is commonly used for perceptual feature extraction. The model implemented
in this file can be used for evaluation (like computing LPIPS, perceptual path
length, etc.), OR be used in training for loss computation (like perceptual
loss, etc.).

The pre-trained model is officially shared by

https://www.robots.ox.ac.uk/~vgg/research/very_deep/

and ported by

https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt

Compared to the official VGG16 model, this ported model also supports evaluating
LPIPS, which is introduced in

https://github.com/richzhang/PerceptualSimilarity
"""

import warnings
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from utils.misc import download_url

__all__ = ['PerceptualModel']

# pylint: disable=line-too-long
_MODEL_URL_SHA256 = {
    # This model is provided by `torchvision`, which is ported from TensorFlow.
    'torchvision_official': (
        'https://download.pytorch.org/models/vgg16-397923af.pth',
        '397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0'  # hash sha256
    ),

    # This model is provided by https://github.com/NVlabs/stylegan2-ada-pytorch
    'vgg_perceptual_lpips': (
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt',
        'b437eb095feaeb0b83eb3fa11200ebca4548ee39a07fb944a417ddc516cc07c3'  # hash sha256
    )
}
# pylint: enable=line-too-long


class PerceptualModel(object):
    """Defines the perceptual model, which is based on the VGG16 structure.

    This is a static class, which is used to avoid building this model
    repeatedly. Consequently, this model is particularly used for inference,
    like computing LPIPS, or for loss computation, like perceptual loss. If
    training is required, please use the model from `torchvision.models` or
    implement it by yourself.

    NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
    order and pixel range [-1, 1], and will NOT resize the input automatically
    if only the perceptual feature is needed.
    """
    models = dict()

    @staticmethod
    def build_model(use_torchvision=False, no_top=True, enable_lpips=True):
        """Builds the model and loads pre-trained weights.

        1. If `use_torchvision` is set as True, the model released by
           `torchvision` will be loaded, otherwise, the model released by
           https://www.robots.ox.ac.uk/~vgg/research/very_deep/ will be used.
           (default: False)

        2. To save computing resources, there is an option to only load the
           backbone (i.e., without the last three fully-connected layers). This
           is commonly used for perceptual loss or LPIPS loss computation.
           Please use the argument `no_top` to control this. (default: True)

        3. For LPIPS loss computation, some additional weights (which are used
           for balancing the features from different resolutions) are employed
           on top of the original VGG16 backbone. Details can be found at
           https://github.com/richzhang/PerceptualSimilarity. Please use
           `enable_lpips` to enable this feature. (default: True)

        The built model supports the following arguments when forwarding:

        - resize_input: Whether to resize the input image to size [224, 224]
          before forwarding. For feature-based computation (i.e., only
          convolutional layers are used), image resizing is not essential.
          (default: False)
        - return_tensor: This field resolves the model behavior. The following
          options are supported:
            `feature1`: Before the first max pooling layer.
            `pool1`: After the first max pooling layer.
            `feature2`: Before the second max pooling layer.
            `pool2`: After the second max pooling layer.
            `feature3`: Before the third max pooling layer.
            `pool3`: After the third max pooling layer.
            `feature4`: Before the fourth max pooling layer.
            `pool4`: After the fourth max pooling layer.
            `feature5`: Before the fifth max pooling layer.
            `pool5`: After the fifth max pooling layer.
            `flatten`: The flattened feature, after `adaptive_avgpool`.
            `feature`: The 4096d feature for logits computation. (default)
            `logits`: The 1000d categorical logits.
            `prediction`: The 1000d predicted probability.
            `lpips`: The LPIPS score between two input images.
        """
        if use_torchvision:
            model_source = 'torchvision_official'
            align_tf_resize = False
            is_torch_script = False
        else:
            model_source = 'vgg_perceptual_lpips'
            align_tf_resize = True
            is_torch_script = True

        if enable_lpips and model_source != 'vgg_perceptual_lpips':
            warnings.warn('The pre-trained model officially released by '
                          '`torchvision` does not support LPIPS computation! '
                          'Equal weights will be used for each resolution.')

        fingerprint = (model_source, no_top, enable_lpips)

        if fingerprint not in PerceptualModel.models:
            # Build model.
            model = VGG16(align_tf_resize=align_tf_resize,
                          no_top=no_top,
                          enable_lpips=enable_lpips)

            # Download pre-trained weights.
            if dist.is_initialized() and dist.get_rank() != 0:
                dist.barrier()  # Download by chief.

            url, sha256 = _MODEL_URL_SHA256[model_source]
            filename = f'perceptual_model_{model_source}_{sha256}.pth'
            model_path, hash_check = download_url(url,
                                                  filename=filename,
                                                  sha256=sha256)
            if is_torch_script:
                src_state_dict = torch.jit.load(model_path, map_location='cpu')
            else:
                src_state_dict = torch.load(model_path, map_location='cpu')
            if hash_check is False:
                warnings.warn(f'Hash check failed! The remote file from URL '
                              f'`{url}` may be changed, or the downloading is '
                              f'interrupted. The loaded perceptual model may '
                              f'have unexpected behavior.')

            if dist.is_initialized() and dist.get_rank() == 0:
                dist.barrier()  # Wait for other replicas.

            # Load weights.
            dst_state_dict = _convert_weights(src_state_dict, model_source)
            model.load_state_dict(dst_state_dict, strict=False)
            del src_state_dict, dst_state_dict

            # For inference only.
            model.eval().requires_grad_(False).cuda()
            PerceptualModel.models[fingerprint] = model

        return PerceptualModel.models[fingerprint]

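
# NOTE (added for clarity): a minimal usage sketch of the LPIPS path of the
# model built above. It assumes a CUDA device; inputs are RGB in [-1, 1] and
# the two batches must share a shape. Defined but never called, so importing
# stays side-effect free.
def _lpips_example():
    vgg = PerceptualModel.build_model(no_top=True, enable_lpips=True)
    img_a = torch.randn(4, 3, 256, 256).cuda()
    img_b = torch.randn(4, 3, 256, 256).cuda()
    with torch.no_grad():
        dist_ab = vgg(img_a, img_b, return_tensor='lpips')
    return dist_ab  # shape [4] -- one distance per image pair
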
def _convert_weights(src_state_dict, model_source):
    if model_source not in _MODEL_URL_SHA256:
        raise ValueError(f'Invalid model source `{model_source}`!\n'
                         f'Sources allowed: {list(_MODEL_URL_SHA256.keys())}.')
    if model_source == 'torchvision_official':
        dst_to_src_var_mapping = {
            'conv11.weight': 'features.0.weight',
            'conv11.bias': 'features.0.bias',
            'conv12.weight': 'features.2.weight',
            'conv12.bias': 'features.2.bias',
            'conv21.weight': 'features.5.weight',
            'conv21.bias': 'features.5.bias',
            'conv22.weight': 'features.7.weight',
            'conv22.bias': 'features.7.bias',
            'conv31.weight': 'features.10.weight',
            'conv31.bias': 'features.10.bias',
            'conv32.weight': 'features.12.weight',
            'conv32.bias': 'features.12.bias',
            'conv33.weight': 'features.14.weight',
            'conv33.bias': 'features.14.bias',
            'conv41.weight': 'features.17.weight',
            'conv41.bias': 'features.17.bias',
            'conv42.weight': 'features.19.weight',
            'conv42.bias': 'features.19.bias',
            'conv43.weight': 'features.21.weight',
            'conv43.bias': 'features.21.bias',
            'conv51.weight': 'features.24.weight',
            'conv51.bias': 'features.24.bias',
            'conv52.weight': 'features.26.weight',
            'conv52.bias': 'features.26.bias',
            'conv53.weight': 'features.28.weight',
            'conv53.bias': 'features.28.bias',
            'fc1.weight': 'classifier.0.weight',
            'fc1.bias': 'classifier.0.bias',
            'fc2.weight': 'classifier.3.weight',
            'fc2.bias': 'classifier.3.bias',
            'fc3.weight': 'classifier.6.weight',
            'fc3.bias': 'classifier.6.bias',
        }
    elif model_source == 'vgg_perceptual_lpips':
        src_state_dict = src_state_dict.state_dict()
        dst_to_src_var_mapping = {
            'conv11.weight': 'layers.conv1.weight',
            'conv11.bias': 'layers.conv1.bias',
            'conv12.weight': 'layers.conv2.weight',
            'conv12.bias': 'layers.conv2.bias',
            'conv21.weight': 'layers.conv3.weight',
            'conv21.bias': 'layers.conv3.bias',
            'conv22.weight': 'layers.conv4.weight',
            'conv22.bias': 'layers.conv4.bias',
            'conv31.weight': 'layers.conv5.weight',
            'conv31.bias': 'layers.conv5.bias',
            'conv32.weight': 'layers.conv6.weight',
            'conv32.bias': 'layers.conv6.bias',
            'conv33.weight': 'layers.conv7.weight',
            'conv33.bias': 'layers.conv7.bias',
            'conv41.weight': 'layers.conv8.weight',
            'conv41.bias': 'layers.conv8.bias',
            'conv42.weight': 'layers.conv9.weight',
            'conv42.bias': 'layers.conv9.bias',
            'conv43.weight': 'layers.conv10.weight',
            'conv43.bias': 'layers.conv10.bias',
            'conv51.weight': 'layers.conv11.weight',
            'conv51.bias': 'layers.conv11.bias',
            'conv52.weight': 'layers.conv12.weight',
            'conv52.bias': 'layers.conv12.bias',
            'conv53.weight': 'layers.conv13.weight',
            'conv53.bias': 'layers.conv13.bias',
            'fc1.weight': 'layers.fc1.weight',
            'fc1.bias': 'layers.fc1.bias',
            'fc2.weight': 'layers.fc2.weight',
            'fc2.bias': 'layers.fc2.bias',
            'fc3.weight': 'layers.fc3.weight',
            'fc3.bias': 'layers.fc3.bias',
            'lpips.0.weight': 'lpips0',
            'lpips.1.weight': 'lpips1',
            'lpips.2.weight': 'lpips2',
            'lpips.3.weight': 'lpips3',
            'lpips.4.weight': 'lpips4',
        }
    else:
        raise NotImplementedError(f'Not implemented model source '
                                  f'`{model_source}`!')

    dst_state_dict = {}
    for dst_name, src_name in dst_to_src_var_mapping.items():
        if dst_name.startswith('lpips'):
            dst_state_dict[dst_name] = src_state_dict[src_name].unsqueeze(0)
        else:
            dst_state_dict[dst_name] = src_state_dict[src_name].clone()
    return dst_state_dict


_IMG_MEAN = (0.485, 0.456, 0.406)
_IMG_STD = (0.229, 0.224, 0.225)
_ALLOWED_RETURN = [
    'feature1', 'pool1', 'feature2', 'pool2', 'feature3', 'pool3', 'feature4',
    'pool4', 'feature5', 'pool5', 'flatten', 'feature', 'logits', 'prediction',
    'lpips'
]

# pylint: disable=missing-function-docstring

class VGG16(nn.Module):
    """Defines the VGG16 structure.

    This model takes `RGB` images with data format `NCHW` as the raw inputs.
    The pixel range is assumed to be [-1, 1].
    """

    def __init__(self, align_tf_resize=False, no_top=True, enable_lpips=True):
        """Defines the network structure."""
        super().__init__()

        self.align_tf_resize = align_tf_resize
        self.no_top = no_top
        self.enable_lpips = enable_lpips

        self.conv11 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu11 = nn.ReLU(inplace=True)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.relu12 = nn.ReLU(inplace=True)
        # output `feature1`, with shape [N, 64, 224, 224]

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool1`, with shape [N, 64, 112, 112]

        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu21 = nn.ReLU(inplace=True)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.relu22 = nn.ReLU(inplace=True)
        # output `feature2`, with shape [N, 128, 112, 112]

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool2`, with shape [N, 128, 56, 56]

        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.relu31 = nn.ReLU(inplace=True)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.relu32 = nn.ReLU(inplace=True)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.relu33 = nn.ReLU(inplace=True)
        # output `feature3`, with shape [N, 256, 56, 56]

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool3`, with shape [N, 256, 28, 28]

        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.relu41 = nn.ReLU(inplace=True)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu42 = nn.ReLU(inplace=True)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu43 = nn.ReLU(inplace=True)
        # output `feature4`, with shape [N, 512, 28, 28]

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool4`, with shape [N, 512, 14, 14]

        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu51 = nn.ReLU(inplace=True)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu52 = nn.ReLU(inplace=True)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.relu53 = nn.ReLU(inplace=True)
        # output `feature5`, with shape [N, 512, 14, 14]

        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        # output `pool5`, with shape [N, 512, 7, 7]

        if self.enable_lpips:
            self.lpips = nn.ModuleList()
            for idx, ch in enumerate([64, 128, 256, 512, 512]):
                self.lpips.append(nn.Conv2d(ch, 1, kernel_size=1, bias=False))
                self.lpips[idx].weight.data.copy_(torch.ones(1, ch, 1, 1))

        if not self.no_top:
            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
            self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
            # output `flatten`, with shape [N, 25088]

            self.fc1 = nn.Linear(512 * 7 * 7, 4096)
            self.fc1_relu = nn.ReLU(inplace=True)
            self.fc1_dropout = nn.Dropout(0.5, inplace=False)
            self.fc2 = nn.Linear(4096, 4096)
            self.fc2_relu = nn.ReLU(inplace=True)
            self.fc2_dropout = nn.Dropout(0.5, inplace=False)
            # output `feature`, with shape [N, 4096]

            self.fc3 = nn.Linear(4096, 1000)
            # output `logits`, with shape [N, 1000]

            self.out = nn.Softmax(dim=1)
            # output `softmax`, with shape [N, 1000]

        img_mean = np.array(_IMG_MEAN).reshape((1, 3, 1, 1)).astype(np.float32)
        img_std = np.array(_IMG_STD).reshape((1, 3, 1, 1)).astype(np.float32)
        self.register_buffer('img_mean', torch.from_numpy(img_mean))
        self.register_buffer('img_std', torch.from_numpy(img_std))

    def forward(self,
                x,
                y=None,
                *,
                resize_input=False,
                return_tensor='feature'):
        return_tensor = return_tensor.lower()
        if return_tensor not in _ALLOWED_RETURN:
            raise ValueError(f'Invalid output tensor name `{return_tensor}` '
                             f'for perceptual model (VGG16)!\n'
                             f'Names allowed: {_ALLOWED_RETURN}.')

        if return_tensor == 'lpips' and y is None:
            raise ValueError('Two images are required for LPIPS computation, '
                             'but only one is received!')

        if return_tensor == 'lpips':
            assert x.shape == y.shape
            x = torch.cat([x, y], dim=0)
            features = []

        if resize_input:
            if self.align_tf_resize:
                theta = torch.eye(2, 3).to(x)
                theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 224
                theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 224
                theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
                grid = F.affine_grid(theta,
                                     size=(x.shape[0], x.shape[1], 224, 224),
                                     align_corners=False)
                x = F.grid_sample(x, grid,
                                  mode='bilinear',
                                  padding_mode='border',
                                  align_corners=False)
            else:
                x = F.interpolate(x,
                                  size=(224, 224),
                                  mode='bilinear',
                                  align_corners=False)
        if x.shape[1] == 1:
            x = x.repeat((1, 3, 1, 1))

        x = (x + 1) / 2
        x = (x - self.img_mean) / self.img_std

        x = self.conv11(x)
        x = self.relu11(x)
        x = self.conv12(x)
        x = self.relu12(x)
        if return_tensor == 'feature1':
            return x
        if return_tensor == 'lpips':
            features.append(x)

        x = self.pool1(x)
        if return_tensor == 'pool1':
            return x

        x = self.conv21(x)
        x = self.relu21(x)
        x = self.conv22(x)
        x = self.relu22(x)
        if return_tensor == 'feature2':
            return x
        if return_tensor == 'lpips':
            features.append(x)

        x = self.pool2(x)
        if return_tensor == 'pool2':
            return x

        x = self.conv31(x)
        x = self.relu31(x)
        x = self.conv32(x)
        x = self.relu32(x)
        x = self.conv33(x)
        x = self.relu33(x)
        if return_tensor == 'feature3':
            return x
        if return_tensor == 'lpips':
            features.append(x)

        x = self.pool3(x)
        if return_tensor == 'pool3':
            return x

        x = self.conv41(x)
        x = self.relu41(x)
        x = self.conv42(x)
        x = self.relu42(x)
        x = self.conv43(x)
        x = self.relu43(x)
        if return_tensor == 'feature4':
            return x
        if return_tensor == 'lpips':
            features.append(x)

        x = self.pool4(x)
        if return_tensor == 'pool4':
            return x

        x = self.conv51(x)
        x = self.relu51(x)
        x = self.conv52(x)
        x = self.relu52(x)
        x = self.conv53(x)
        x = self.relu53(x)
        if return_tensor == 'feature5':
            return x
        if return_tensor == 'lpips':
            features.append(x)

        x = self.pool5(x)
        if return_tensor == 'pool5':
            return x

        if return_tensor == 'lpips':
            score = 0
            assert len(features) == 5
            for idx in range(5):
                feature = features[idx]
                norm = feature.norm(dim=1, keepdim=True)
                feature = feature / (norm + 1e-10)
                feature_x, feature_y = feature.chunk(2, dim=0)
                diff = (feature_x - feature_y).square()
                score += self.lpips[idx](diff).mean(dim=(2, 3), keepdim=False)
            return score.sum(dim=1, keepdim=False)

        x = self.avgpool(x)
        x = self.flatten(x)
        if return_tensor == 'flatten':
            return x

        x = self.fc1(x)
        x = self.fc1_relu(x)
        x = self.fc1_dropout(x)
        x = self.fc2(x)
        x = self.fc2_relu(x)
        x = self.fc2_dropout(x)
        if return_tensor == 'feature':
            return x

        x = self.fc3(x)
        if return_tensor == 'logits':
            return x

        x = self.out(x)
        if return_tensor == 'prediction':
            return x

        raise NotImplementedError(f'Output tensor name `{return_tensor}` is '
                                  f'not implemented!')

# pylint: enable=missing-function-docstring
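
For reference (added for clarity), the `lpips` branch of `VGG16.forward()` above computes score(x, y) = sum over the five conv stages l of w_l applied to mean_{h,w} (f_l(x)/||f_l(x)||_2 - f_l(y)/||f_l(y)||_2)^2: activations are unit-normalized along the channel axis, squared differences are averaged spatially, weighted by the learned 1x1 convolutions stored in `self.lpips`, and summed across stages. With the `torchvision` source those weights fall back to all ones, as warned in `build_model()`.
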
models/pggan_discriminator.py
ADDED
@@ -0,0 +1,465 @@
# python3.7
"""Contains the implementation of discriminator described in PGGAN.

Paper: https://arxiv.org/pdf/1710.10196.pdf

Official TensorFlow implementation:
https://github.com/tkarras/progressive_growing_of_gans
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PGGANDiscriminator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Default gain factor for weight scaling.
_WSCALE_GAIN = np.sqrt(2.0)

# pylint: disable=missing-function-docstring

class PGGANDiscriminator(nn.Module):
    """Defines the discriminator network in PGGAN.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the network:

    (1) resolution: The resolution of the input image.
    (2) init_res: Smallest resolution of the convolutional backbone.
        (default: 4)
    (3) image_channels: Number of channels of the input image. (default: 3)
    (4) label_dim: Dimension of the additional label for conditional
        generation. In the one-hot conditioning case, it is equal to the
        number of classes. If set to 0, conditional training will be
        disabled. (default: 0)
    (5) fused_scale: Whether to fuse `conv2d` and `downsample` together,
        resulting in `conv2d` with strides. (default: False)
    (6) use_wscale: Whether to use weight scaling. (default: True)
    (7) wscale_gain: The factor to control weight scaling.
        (default: sqrt(2.0))
    (8) mbstd_groups: Group size for the minibatch standard deviation layer.
        `0` means disable. (default: 16)
    (9) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (10) fmaps_max: Maximum number of feature maps in each layer.
        (default: 512)
    (11) eps: A small value to avoid divide overflow. (default: 1e-8)
    """

    def __init__(self,
                 resolution,
                 init_res=4,
                 image_channels=3,
                 label_dim=0,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=np.sqrt(2.0),
                 mbstd_groups=16,
                 fmaps_base=16 << 10,
                 fmaps_max=512,
                 eps=1e-8):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.init_res = init_res
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.image_channels = image_channels
        self.label_dim = label_dim
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.mbstd_groups = mbstd_groups
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.eps = eps

        # Level-of-details (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            in_channels = self.get_nf(res)
            out_channels = self.get_nf(res // 2)
            block_idx = self.final_res_log2 - res_log2

            # Input convolution layer for each resolution.
            self.add_module(
                f'input{block_idx}',
                ConvLayer(in_channels=self.image_channels,
                          out_channels=in_channels,
                          kernel_size=1,
                          add_bias=True,
                          downsample=False,
                          fused_scale=False,
                          use_wscale=use_wscale,
                          wscale_gain=wscale_gain,
                          activation_type='lrelu'))
            self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
                f'FromRGB_lod{block_idx}/weight')
            self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
                f'FromRGB_lod{block_idx}/bias')

            # Convolution block for each resolution (except the last one).
            if res != self.init_res:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvLayer(in_channels=in_channels,
                              out_channels=in_channels,
                              kernel_size=3,
                              add_bias=True,
                              downsample=False,
                              fused_scale=False,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              activation_type='lrelu'))
                tf_layer0_name = 'Conv0'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    ConvLayer(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=3,
                              add_bias=True,
                              downsample=True,
                              fused_scale=fused_scale,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              activation_type='lrelu'))
                tf_layer1_name = 'Conv1_down' if fused_scale else 'Conv1'

            # Convolution block for last resolution.
            else:
                self.mbstd = MiniBatchSTDLayer(groups=mbstd_groups, eps=eps)
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvLayer(
                        in_channels=in_channels + 1,
                        out_channels=in_channels,
                        kernel_size=3,
                        add_bias=True,
                        downsample=False,
                        fused_scale=False,
                        use_wscale=use_wscale,
                        wscale_gain=wscale_gain,
                        activation_type='lrelu'))
                tf_layer0_name = 'Conv'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    DenseLayer(in_channels=in_channels * res * res,
                               out_channels=out_channels,
                               add_bias=True,
                               use_wscale=use_wscale,
                               wscale_gain=wscale_gain,
                               activation_type='lrelu'))
                tf_layer1_name = 'Dense0'

            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer0_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer0_name}/bias')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer1_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer1_name}/bias')

        # Final dense layer.
        self.output = DenseLayer(in_channels=out_channels,
                                 out_channels=1 + self.label_dim,
                                 add_bias=True,
                                 use_wscale=self.use_wscale,
                                 wscale_gain=1.0,
                                 activation_type='linear')
        self.pth_to_tf_var_mapping['output.weight'] = (
            f'{res}x{res}/Dense1/weight')
        self.pth_to_tf_var_mapping['output.bias'] = (
            f'{res}x{res}/Dense1/bias')

    def get_nf(self, res):
        """Gets number of feature maps according to the given resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, image, lod=None):
        expected_shape = (self.image_channels, self.resolution, self.resolution)
        if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals to {self.image_channels}, '
                             f'`height`, `width` equal to {self.resolution}!\n'
                             f'But `{image.shape}` is received!')

        lod = self.lod.item() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-details (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            block_idx = current_lod = self.final_res_log2 - res_log2
            if current_lod <= lod < current_lod + 1:
                x = getattr(self, f'input{block_idx}')(image)
            elif current_lod - 1 < lod < current_lod:
                alpha = lod - np.floor(lod)
                y = getattr(self, f'input{block_idx}')(image)
                x = y * alpha + x * (1 - alpha)
            if lod < current_lod + 1:
                if res_log2 == self.init_res_log2:
                    x = self.mbstd(x)
                x = getattr(self, f'layer{2 * block_idx}')(x)
                x = getattr(self, f'layer{2 * block_idx + 1}')(x)
            if lod > current_lod:
                image = F.avg_pool2d(
                    image, kernel_size=2, stride=2, padding=0)
        x = self.output(x)

        return {'score': x}
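A minimal usage sketch for the discriminator above (run from the repository root). Note that the input image is always given at full resolution; the `lod` argument only controls which blocks are active and how two adjacent `input{block_idx}` heads are cross-faded during progressive growth:

import torch

from models.pggan_discriminator import PGGANDiscriminator

D = PGGANDiscriminator(resolution=256)

x = torch.randn(2, 3, 256, 256)
score = D(x, lod=0)['score']   # Fully grown network; shape [2, 1].

# Fractional lod: the 128x128 head (reading the average-pooled image)
# is blended in with weight alpha = lod - floor(lod) = 0.5.
score = D(x, lod=0.5)['score']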
class MiniBatchSTDLayer(nn.Module):
    """Implements the minibatch standard deviation layer."""

    def __init__(self, groups, eps):
        super().__init__()
        self.groups = groups
        self.eps = eps

    def extra_repr(self):
        return f'groups={self.groups}, epsilon={self.eps}'

    def forward(self, x):
        if self.groups <= 1:
            return x

        N, C, H, W = x.shape
        G = min(self.groups, N)  # Number of groups.

        y = x.reshape(G, -1, C, H, W)            # [GnCHW]
        y = y - y.mean(dim=0)                    # [GnCHW]
        y = y.square().mean(dim=0)               # [nCHW]
        y = (y + self.eps).sqrt()                # [nCHW]
        y = y.mean(dim=(1, 2, 3), keepdim=True)  # [n111]
        y = y.repeat(G, 1, H, W)                 # [N1HW]
        x = torch.cat([x, y], dim=1)             # [N(C+1)HW]

        return x
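A quick shape check for the layer above: the per-group standard deviation collapses to one scalar per subgroup of the batch and is broadcast back as a single extra feature map, so the channel count grows by exactly one.

import torch

from models.pggan_discriminator import MiniBatchSTDLayer

mbstd = MiniBatchSTDLayer(groups=4, eps=1e-8)
x = torch.randn(8, 512, 4, 4)
y = mbstd(x)
assert y.shape == (8, 513, 4, 4)  # One appended std channel.
# The appended map is spatially constant for each sample.
assert torch.allclose(y[0, 512], y[0, 512, 0, 0].expand(4, 4))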
class DownsamplingLayer(nn.Module):
    """Implements the downsampling layer.

    Basically, this layer can be used to downsample feature maps with average
    pooling.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def extra_repr(self):
        return f'factor={self.scale_factor}'

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.avg_pool2d(x,
                            kernel_size=self.scale_factor,
                            stride=self.scale_factor,
                            padding=0)


class ConvLayer(nn.Module):
    """Implements the convolutional layer.

    Basically, this layer executes convolution, activation, and downsampling
    (if needed) in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 add_bias,
                 downsample,
                 fused_scale,
                 use_wscale,
                 wscale_gain,
                 activation_type):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels.
            add_bias: Whether to add bias onto the convolutional result.
            downsample: Whether to downsample the result after convolution.
            fused_scale: Whether to fuse `conv2d` and `downsample` together,
                resulting in `conv2d` with strides.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            activation_type: Type of activation.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.add_bias = add_bias
        self.downsample = downsample
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.activation_type = activation_type

        if downsample and not fused_scale:
            self.down = DownsamplingLayer(scale_factor=2)
        else:
            self.down = nn.Identity()

        if downsample and fused_scale:
            self.use_stride = True
            self.stride = 2
            self.padding = 1
        else:
            self.use_stride = False
            self.stride = 1
            self.padding = kernel_size // 2

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        assert activation_type in ['linear', 'relu', 'lrelu']

    def extra_repr(self):
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'ksize={self.kernel_size}, '
                f'wscale_gain={self.wscale_gain:.3f}, '
                f'bias={self.add_bias}, '
                f'downsample={self.downsample}, '
                f'fused_scale={self.fused_scale}, '
                f'act={self.activation_type}')

    def forward(self, x):
        weight = self.weight
        if self.wscale != 1.0:
            weight = weight * self.wscale

        if self.use_stride:
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
        x = F.conv2d(x,
                     weight=weight,
                     bias=self.bias,
                     stride=self.stride,
                     padding=self.padding)

        if self.activation_type == 'linear':
            pass
        elif self.activation_type == 'relu':
            x = F.relu(x, inplace=True)
        elif self.activation_type == 'lrelu':
            x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation type '
                                      f'`{self.activation_type}`!')
        x = self.down(x)

        return x
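The `use_stride` branch above pads the 3x3 kernel and averages its four shifted copies, which yields a 4x4 kernel equal to the original kernel convolved with a 2x2 box filter; applied with stride 2, this makes one convolution behave like "blur, then downsample", which is what fusing `conv2d` and `downsample` means here. A small sketch of that kernel transform in isolation (the tensor names are illustrative):

import torch
import torch.nn.functional as F

w = torch.randn(16, 8, 3, 3)  # [out_ch, in_ch, kh, kw]

# Pad the spatial dims only, then average the four overlapping 4x4
# sub-windows of the padded 5x5 kernel, exactly as in `ConvLayer.forward`.
w_pad = F.pad(w, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
w_fused = (w_pad[:, :, 1:, 1:] + w_pad[:, :, :-1, 1:] +
           w_pad[:, :, 1:, :-1] + w_pad[:, :, :-1, :-1]) * 0.25
print(w_fused.shape)  # torch.Size([16, 8, 4, 4])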
397 |
+
"""Implements the dense layer."""
|
398 |
+
|
399 |
+
def __init__(self,
|
400 |
+
in_channels,
|
401 |
+
out_channels,
|
402 |
+
add_bias,
|
403 |
+
use_wscale,
|
404 |
+
wscale_gain,
|
405 |
+
activation_type):
|
406 |
+
"""Initializes with layer settings.
|
407 |
+
|
408 |
+
Args:
|
409 |
+
in_channels: Number of channels of the input tensor.
|
410 |
+
out_channels: Number of channels of the output tensor.
|
411 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
412 |
+
use_wscale: Whether to use weight scaling.
|
413 |
+
wscale_gain: Gain factor for weight scaling.
|
414 |
+
activation_type: Type of activation.
|
415 |
+
|
416 |
+
Raises:
|
417 |
+
NotImplementedError: If the `activation_type` is not supported.
|
418 |
+
"""
|
419 |
+
super().__init__()
|
420 |
+
self.in_channels = in_channels
|
421 |
+
self.out_channels = out_channels
|
422 |
+
self.add_bias = add_bias
|
423 |
+
self.use_wscale = use_wscale
|
424 |
+
self.wscale_gain = wscale_gain
|
425 |
+
self.activation_type = activation_type
|
426 |
+
|
427 |
+
weight_shape = (out_channels, in_channels)
|
428 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
429 |
+
if use_wscale:
|
430 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape))
|
431 |
+
self.wscale = wscale
|
432 |
+
else:
|
433 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
|
434 |
+
self.wscale = 1.0
|
435 |
+
|
436 |
+
if add_bias:
|
437 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
438 |
+
else:
|
439 |
+
self.bias = None
|
440 |
+
|
441 |
+
assert activation_type in ['linear', 'relu', 'lrelu']
|
442 |
+
|
443 |
+
def forward(self, x):
|
444 |
+
if x.ndim != 2:
|
445 |
+
x = x.flatten(start_dim=1)
|
446 |
+
|
447 |
+
weight = self.weight
|
448 |
+
if self.wscale != 1.0:
|
449 |
+
weight = weight * self.wscale
|
450 |
+
|
451 |
+
x = F.linear(x, weight=weight, bias=self.bias)
|
452 |
+
|
453 |
+
if self.activation_type == 'linear':
|
454 |
+
pass
|
455 |
+
elif self.activation_type == 'relu':
|
456 |
+
x = F.relu(x, inplace=True)
|
457 |
+
elif self.activation_type == 'lrelu':
|
458 |
+
x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
|
459 |
+
else:
|
460 |
+
raise NotImplementedError(f'Not implemented activation type '
|
461 |
+
f'`{self.activation_type}`!')
|
462 |
+
|
463 |
+
return x
|
464 |
+
|
465 |
+
# pylint: enable=missing-function-docstring
|
models/pggan_generator.py
ADDED
@@ -0,0 +1,401 @@
# python3.7
"""Contains the implementation of generator described in PGGAN.

Paper: https://arxiv.org/pdf/1710.10196.pdf

Official TensorFlow implementation:
https://github.com/tkarras/progressive_growing_of_gans
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PGGANGenerator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# pylint: disable=missing-function-docstring

class PGGANGenerator(nn.Module):
    """Defines the generator network in PGGAN.

    NOTE: The synthesized images are with `RGB` channel order and pixel range
    [-1, 1].

    Settings for the network:

    (1) resolution: The resolution of the output image.
    (2) init_res: The initial resolution to start with convolution.
        (default: 4)
    (3) z_dim: Dimension of the input latent space, Z. (default: 512)
    (4) image_channels: Number of channels of the output image. (default: 3)
    (5) final_tanh: Whether to use `tanh` to control the final pixel range.
        (default: False)
    (6) label_dim: Dimension of the additional label for conditional
        generation. In the one-hot conditioning case, it is equal to the
        number of classes. If set to 0, conditional training will be
        disabled. (default: 0)
    (7) fused_scale: Whether to fuse `upsample` and `conv2d` together,
        resulting in `conv2d_transpose`. (default: False)
    (8) use_wscale: Whether to use weight scaling. (default: True)
    (9) wscale_gain: The factor to control weight scaling.
        (default: sqrt(2.0))
    (10) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (11) fmaps_max: Maximum number of feature maps in each layer.
        (default: 512)
    (12) eps: A small value to avoid divide overflow. (default: 1e-8)
    """

    def __init__(self,
                 resolution,
                 init_res=4,
                 z_dim=512,
                 image_channels=3,
                 final_tanh=False,
                 label_dim=0,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=np.sqrt(2.0),
                 fmaps_base=16 << 10,
                 fmaps_max=512,
                 eps=1e-8):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.init_res = init_res
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.z_dim = z_dim
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.label_dim = label_dim
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.eps = eps

        # Dimension of latent space, which is convenient for sampling.
        self.latent_dim = (self.z_dim,)

        # Number of convolutional layers.
        self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2

        # Level-of-details (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            res = 2 ** res_log2
            in_channels = self.get_nf(res // 2)
            out_channels = self.get_nf(res)
            block_idx = res_log2 - self.init_res_log2

            # First convolution layer for each resolution.
            if res == self.init_res:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvLayer(in_channels=z_dim + label_dim,
                              out_channels=out_channels,
                              kernel_size=init_res,
                              padding=init_res - 1,
                              add_bias=True,
                              upsample=False,
                              fused_scale=False,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              activation_type='lrelu',
                              eps=eps))
                tf_layer_name = 'Dense'
            else:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvLayer(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=3,
                              padding=1,
                              add_bias=True,
                              upsample=True,
                              fused_scale=fused_scale,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              activation_type='lrelu',
                              eps=eps))
                tf_layer_name = 'Conv0_up' if fused_scale else 'Conv0'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Second convolution layer for each resolution.
            self.add_module(
                f'layer{2 * block_idx + 1}',
                ConvLayer(in_channels=out_channels,
                          out_channels=out_channels,
                          kernel_size=3,
                          padding=1,
                          add_bias=True,
                          upsample=False,
                          fused_scale=False,
                          use_wscale=use_wscale,
                          wscale_gain=wscale_gain,
                          activation_type='lrelu',
                          eps=eps))
            tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Output convolution layer for each resolution.
            self.add_module(
                f'output{block_idx}',
                ConvLayer(in_channels=out_channels,
                          out_channels=image_channels,
                          kernel_size=1,
                          padding=0,
                          add_bias=True,
                          upsample=False,
                          fused_scale=False,
                          use_wscale=use_wscale,
                          wscale_gain=1.0,
                          activation_type='linear',
                          eps=eps))
            self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
            self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')

    def get_nf(self, res):
        """Gets number of feature maps according to the given resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, z, label=None, lod=None):
        if z.ndim != 2 or z.shape[1] != self.z_dim:
            raise ValueError(f'Input latent code should be with shape '
                             f'[batch_size, latent_dim], where '
                             f'`latent_dim` equals to {self.z_dim}!\n'
                             f'But `{z.shape}` is received!')
        z = self.layer0.pixel_norm(z)
        if self.label_dim:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with size {self.label_dim}) as input, '
                                 f'but no label is received!')
            if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals to that of '
                                 f'latent codes ({z.shape[0]}) and '
                                 f'`label_dim` equals to {self.label_dim}!\n'
                                 f'But `{label.shape}` is received!')
            label = label.to(dtype=torch.float32)
            z = torch.cat((z, label), dim=1)

        lod = self.lod.item() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-details (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        x = z.view(z.shape[0], self.z_dim + self.label_dim, 1, 1)
        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            current_lod = self.final_res_log2 - res_log2
            block_idx = res_log2 - self.init_res_log2
            if lod < current_lod + 1:
                x = getattr(self, f'layer{2 * block_idx}')(x)
                x = getattr(self, f'layer{2 * block_idx + 1}')(x)
            if current_lod - 1 < lod <= current_lod:
                image = getattr(self, f'output{block_idx}')(x)
            elif current_lod < lod < current_lod + 1:
                alpha = np.ceil(lod) - lod
                temp = getattr(self, f'output{block_idx}')(x)
                image = F.interpolate(image, scale_factor=2, mode='nearest')
                image = temp * alpha + image * (1 - alpha)
            elif lod >= current_lod + 1:
                image = F.interpolate(image, scale_factor=2, mode='nearest')
        if self.final_tanh:
            image = torch.tanh(image)

        results = {
            'z': z,
            'label': label,
            'image': image,
        }
        return results
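A minimal usage sketch for the generator above (run from the repository root): sample a batch of latent codes, pick a `lod`, and read the synthesized image out of the returned dictionary.

import torch

from models.pggan_generator import PGGANGenerator

G = PGGANGenerator(resolution=128)
z = torch.randn(4, G.z_dim)

with torch.no_grad():
    out = G(z, lod=0)          # Fully grown network.
print(out['image'].shape)      # torch.Size([4, 3, 128, 128])

# A fractional lod (e.g. 1.5) cross-fades the 64x64 RGB output with the
# upsampled 32x32 one; the result is upsampled to 128x128 either way.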
class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer."""

    def __init__(self, dim, eps):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def extra_repr(self):
        return f'dim={self.dim}, epsilon={self.eps}'

    def forward(self, x):
        scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
        return x * scale
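The layer above computes y = x / sqrt(mean(x^2) + eps) along the chosen dimension, i.e. with `dim=1` every spatial location's feature vector is rescaled to (roughly) unit RMS over channels. A quick numeric check:

import torch

from models.pggan_generator import PixelNormLayer

norm = PixelNormLayer(dim=1, eps=1e-8)
x = torch.randn(2, 512, 4, 4)
y = norm(x)
rms = y.square().mean(dim=1).sqrt()  # Per-location RMS over channels.
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)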
class UpsamplingLayer(nn.Module):
    """Implements the upsampling layer.

    Basically, this layer can be used to upsample feature maps with nearest
    neighbor interpolation.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def extra_repr(self):
        return f'factor={self.scale_factor}'

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')


class ConvLayer(nn.Module):
    """Implements the convolutional layer.

    Basically, this layer executes pixel-wise normalization, upsampling (if
    needed), convolution, and activation in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 padding,
                 add_bias,
                 upsample,
                 fused_scale,
                 use_wscale,
                 wscale_gain,
                 activation_type,
                 eps):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels.
            padding: Padding used in convolution.
            add_bias: Whether to add bias onto the convolutional result.
            upsample: Whether to upsample the input tensor before convolution.
            fused_scale: Whether to fuse `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            activation_type: Type of activation.
            eps: A small value to avoid divide overflow.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.add_bias = add_bias
        self.upsample = upsample
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.activation_type = activation_type
        self.eps = eps

        self.pixel_norm = PixelNormLayer(dim=1, eps=eps)

        if upsample and not fused_scale:
            self.up = UpsamplingLayer(scale_factor=2)
        else:
            self.up = nn.Identity()

        if upsample and fused_scale:
            self.use_conv2d_transpose = True
            weight_shape = (in_channels, out_channels, kernel_size, kernel_size)
            self.stride = 2
            self.padding = 1
        else:
            self.use_conv2d_transpose = False
            weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
            self.stride = 1

        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        assert activation_type in ['linear', 'relu', 'lrelu']

    def extra_repr(self):
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'ksize={self.kernel_size}, '
                f'padding={self.padding}, '
                f'wscale_gain={self.wscale_gain:.3f}, '
                f'bias={self.add_bias}, '
                f'upsample={self.upsample}, '
                f'fused_scale={self.fused_scale}, '
                f'act={self.activation_type}')

    def forward(self, x):
        x = self.pixel_norm(x)
        x = self.up(x)
        weight = self.weight
        if self.wscale != 1.0:
            weight = weight * self.wscale
        if self.use_conv2d_transpose:
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
            x = F.conv_transpose2d(x,
                                   weight=weight,
                                   bias=self.bias,
                                   stride=self.stride,
                                   padding=self.padding)
        else:
            x = F.conv2d(x,
                         weight=weight,
                         bias=self.bias,
                         stride=self.stride,
                         padding=self.padding)

        if self.activation_type == 'linear':
            pass
        elif self.activation_type == 'relu':
            x = F.relu(x, inplace=True)
        elif self.activation_type == 'lrelu':
            x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation type '
                                      f'`{self.activation_type}`!')

        return x

# pylint: enable=missing-function-docstring
models/stylegan2_discriminator.py
ADDED
@@ -0,0 +1,729 @@
# python3.7
"""Contains the implementation of discriminator described in StyleGAN2.

Compared to that of StyleGAN, the discriminator in StyleGAN2 mainly adds skip
connections, increases model size, and disables progressive growth. This
script ONLY supports config F in the original paper.

Paper: https://arxiv.org/pdf/1912.04958.pdf

Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
"""

import numpy as np

import torch
import torch.nn as nn

from third_party.stylegan2_official_ops import bias_act
from third_party.stylegan2_official_ops import upfirdn2d
from third_party.stylegan2_official_ops import conv2d_gradfix

__all__ = ['StyleGAN2Discriminator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Architectures allowed.
_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']

# pylint: disable=missing-function-docstring

class StyleGAN2Discriminator(nn.Module):
    """Defines the discriminator network in StyleGAN2.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the backbone:

    (1) resolution: The resolution of the input image. (default: -1)
    (2) init_res: Smallest resolution of the convolutional backbone.
        (default: 4)
    (3) image_channels: Number of channels of the input image. (default: 3)
    (4) architecture: Type of architecture. Support `origin`, `skip`, and
        `resnet`. (default: `resnet`)
    (5) use_wscale: Whether to use weight scaling. (default: True)
    (6) wscale_gain: The factor to control weight scaling. (default: 1.0)
    (7) lr_mul: Learning rate multiplier for backbone. (default: 1.0)
    (8) mbstd_groups: Group size for the minibatch standard deviation layer.
        `0` means disable. (default: 4)
    (9) mbstd_channels: Number of new channels (appended to the original
        feature map) after the minibatch standard deviation layer.
        (default: 1)
    (10) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 32 << 10)
    (11) fmaps_max: Maximum number of feature maps in each layer.
        (default: 512)
    (12) filter_kernel: Kernel used for filtering (e.g., downsampling).
        (default: (1, 3, 3, 1))
    (13) conv_clamp: A threshold to clamp the output of convolution layers to
        avoid overflow under FP16 training. (default: None)
    (14) eps: A small value to avoid divide overflow. (default: 1e-8)

    Settings for conditional model:

    (1) label_dim: Dimension of the additional label for conditional
        generation. In the one-hot conditioning case, it is equal to the
        number of classes. If set to 0, conditional training will be
        disabled. (default: 0)
    (2) embedding_dim: Dimension of the embedding space, if needed.
        (default: 512)
    (3) embedding_bias: Whether to add bias to embedding learning.
        (default: True)
    (4) embedding_use_wscale: Whether to use weight scaling for embedding
        learning. (default: True)
    (5) embedding_lr_mul: Learning rate multiplier for the embedding learning.
        (default: 1.0)
    (6) normalize_embedding: Whether to normalize the embedding.
        (default: True)
    (7) mapping_layers: Number of layers of the additional mapping network
        after embedding. (default: 0)
    (8) mapping_fmaps: Number of hidden channels of the additional mapping
        network after embedding. (default: 512)
    (9) mapping_use_wscale: Whether to use weight scaling for the additional
        mapping network. (default: True)
    (10) mapping_lr_mul: Learning rate multiplier for the additional mapping
        network after embedding. (default: 0.1)

    Runtime settings:

    (1) fp16_res: Layers at resolution higher than (or equal to) this field
        will use `float16` precision for computation. This is merely used for
        acceleration. If set as `None`, all layers will use `float32` by
        default. (default: None)
    (2) impl: Implementation mode of some particular ops, e.g., `filtering`,
        `bias_act`, etc. `cuda` means using the official CUDA implementation
        from StyleGAN2, while `ref` means using the native PyTorch ops.
        (default: `cuda`)
    """

    def __init__(self,
                 # Settings for backbone.
                 resolution=-1,
                 init_res=4,
                 image_channels=3,
                 architecture='resnet',
                 use_wscale=True,
                 wscale_gain=1.0,
                 lr_mul=1.0,
                 mbstd_groups=4,
                 mbstd_channels=1,
                 fmaps_base=32 << 10,
                 fmaps_max=512,
                 filter_kernel=(1, 3, 3, 1),
                 conv_clamp=None,
                 eps=1e-8,
                 # Settings for conditional model.
                 label_dim=0,
                 embedding_dim=512,
                 embedding_bias=True,
                 embedding_use_wscale=True,
                 embedding_lr_mul=1.0,
                 normalize_embedding=True,
                 mapping_layers=0,
                 mapping_fmaps=512,
                 mapping_use_wscale=True,
                 mapping_lr_mul=0.1):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported, or
                `architecture` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        architecture = architecture.lower()
        if architecture not in _ARCHITECTURES_ALLOWED:
            raise ValueError(f'Invalid architecture: `{architecture}`!\n'
                             f'Architectures allowed: '
                             f'{_ARCHITECTURES_ALLOWED}.')

        self.init_res = init_res
        self.init_res_log2 = int(np.log2(init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(resolution))
        self.image_channels = image_channels
        self.architecture = architecture
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.mbstd_groups = mbstd_groups
        self.mbstd_channels = mbstd_channels
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.filter_kernel = filter_kernel
        self.conv_clamp = conv_clamp
        self.eps = eps

        self.label_dim = label_dim
        self.embedding_dim = embedding_dim
        self.embedding_bias = embedding_bias
        self.embedding_use_wscale = embedding_use_wscale
        self.embedding_lr_mul = embedding_lr_mul
        self.normalize_embedding = normalize_embedding
        self.mapping_layers = mapping_layers
        self.mapping_fmaps = mapping_fmaps
        self.mapping_use_wscale = mapping_use_wscale
        self.mapping_lr_mul = mapping_lr_mul

        self.pth_to_tf_var_mapping = {}

        # Embedding for conditional discrimination.
        self.use_embedding = label_dim > 0 and embedding_dim > 0
        if self.use_embedding:
            self.embedding = DenseLayer(in_channels=label_dim,
                                        out_channels=embedding_dim,
                                        add_bias=embedding_bias,
                                        init_bias=0.0,
                                        use_wscale=embedding_use_wscale,
                                        wscale_gain=wscale_gain,
                                        lr_mul=embedding_lr_mul,
                                        activation_type='linear')
            self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight'
            if self.embedding_bias:
                self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias'

            if self.normalize_embedding:
                self.norm = PixelNormLayer(dim=1, eps=eps)

            for i in range(mapping_layers):
                in_channels = (embedding_dim if i == 0 else mapping_fmaps)
                out_channels = (embedding_dim if i == (mapping_layers - 1) else
                                mapping_fmaps)
                layer_name = f'mapping{i}'
                self.add_module(layer_name,
                                DenseLayer(in_channels=in_channels,
                                           out_channels=out_channels,
                                           add_bias=True,
                                           init_bias=0.0,
                                           use_wscale=mapping_use_wscale,
                                           wscale_gain=wscale_gain,
                                           lr_mul=mapping_lr_mul,
                                           activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'Mapping{i}/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'Mapping{i}/bias')

        # Convolutional backbone.
        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            in_channels = self.get_nf(res)
            out_channels = self.get_nf(res // 2)
            block_idx = self.final_res_log2 - res_log2

            # Input convolution layer for each resolution (if needed).
            if res_log2 == self.final_res_log2 or self.architecture == 'skip':
                layer_name = f'input{block_idx}'
                self.add_module(layer_name,
                                ConvLayer(in_channels=image_channels,
                                          out_channels=in_channels,
                                          kernel_size=1,
                                          add_bias=True,
                                          scale_factor=1,
                                          filter_kernel=None,
                                          use_wscale=use_wscale,
                                          wscale_gain=wscale_gain,
                                          lr_mul=lr_mul,
                                          activation_type='lrelu',
                                          conv_clamp=conv_clamp))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/FromRGB/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/FromRGB/bias')

            # Convolution block for each resolution (except the last one).
            if res != self.init_res:
                # First layer (kernel 3x3) without downsampling.
                layer_name = f'layer{2 * block_idx}'
                self.add_module(layer_name,
                                ConvLayer(in_channels=in_channels,
                                          out_channels=in_channels,
                                          kernel_size=3,
                                          add_bias=True,
                                          scale_factor=1,
                                          filter_kernel=None,
                                          use_wscale=use_wscale,
                                          wscale_gain=wscale_gain,
                                          lr_mul=lr_mul,
                                          activation_type='lrelu',
                                          conv_clamp=conv_clamp))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Conv0/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Conv0/bias')

                # Second layer (kernel 3x3) with downsampling.
                layer_name = f'layer{2 * block_idx + 1}'
                self.add_module(layer_name,
                                ConvLayer(in_channels=in_channels,
                                          out_channels=out_channels,
                                          kernel_size=3,
                                          add_bias=True,
                                          scale_factor=2,
                                          filter_kernel=filter_kernel,
                                          use_wscale=use_wscale,
                                          wscale_gain=wscale_gain,
                                          lr_mul=lr_mul,
                                          activation_type='lrelu',
                                          conv_clamp=conv_clamp))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Conv1_down/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Conv1_down/bias')

                # Residual branch (kernel 1x1) with downsampling, without
                # bias, with linear activation.
                if self.architecture == 'resnet':
                    layer_name = f'residual{block_idx}'
                    self.add_module(layer_name,
                                    ConvLayer(in_channels=in_channels,
                                              out_channels=out_channels,
                                              kernel_size=1,
                                              add_bias=False,
                                              scale_factor=2,
                                              filter_kernel=filter_kernel,
                                              use_wscale=use_wscale,
                                              wscale_gain=wscale_gain,
                                              lr_mul=lr_mul,
                                              activation_type='linear',
                                              conv_clamp=None))
                    self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                        f'{res}x{res}/Skip/weight')

            # Convolution block for last resolution.
            else:
                self.mbstd = MiniBatchSTDLayer(
                    groups=mbstd_groups, new_channels=mbstd_channels, eps=eps)

                # First layer (kernel 3x3) without downsampling.
                layer_name = f'layer{2 * block_idx}'
                self.add_module(
                    layer_name,
                    ConvLayer(in_channels=in_channels + mbstd_channels,
                              out_channels=in_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=1,
                              filter_kernel=None,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              lr_mul=lr_mul,
                              activation_type='lrelu',
                              conv_clamp=conv_clamp))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Conv/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Conv/bias')

                # Second layer, as a fully-connected layer.
                layer_name = f'layer{2 * block_idx + 1}'
                self.add_module(layer_name,
                                DenseLayer(in_channels=in_channels * res * res,
                                           out_channels=in_channels,
                                           add_bias=True,
                                           init_bias=0.0,
                                           use_wscale=use_wscale,
                                           wscale_gain=wscale_gain,
                                           lr_mul=lr_mul,
                                           activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Dense0/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Dense0/bias')

                # Final dense layer to output score.
                self.output = DenseLayer(in_channels=in_channels,
                                         out_channels=(embedding_dim
                                                       if self.use_embedding
                                                       else max(label_dim, 1)),
                                         add_bias=True,
                                         init_bias=0.0,
                                         use_wscale=use_wscale,
                                         wscale_gain=wscale_gain,
                                         lr_mul=lr_mul,
                                         activation_type='linear')
                self.pth_to_tf_var_mapping['output.weight'] = 'Output/weight'
                self.pth_to_tf_var_mapping['output.bias'] = 'Output/bias'

        # Used for downsampling input image for `skip` architecture.
        if self.architecture == 'skip':
            self.register_buffer(
                'filter', upfirdn2d.setup_filter(filter_kernel))

    def get_nf(self, res):
        """Gets number of feature maps according to the given resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, image, label=None, fp16_res=None, impl='cuda'):
        # Check shape.
        expected_shape = (self.image_channels, self.resolution, self.resolution)
        if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals to {self.image_channels}, '
                             f'`height`, `width` equal to {self.resolution}!\n'
                             f'But `{image.shape}` is received!')
        if self.label_dim > 0:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with dimension {self.label_dim}) as input, '
                                 f'but no label is received!')
            batch_size = image.shape[0]
            if label.ndim != 2 or label.shape != (batch_size, self.label_dim):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals to that of '
                                 f'images ({image.shape[0]}) and '
                                 f'`label_dim` equals to {self.label_dim}!\n'
                                 f'But `{label.shape}` is received!')
            label = label.to(dtype=torch.float32)
            if self.use_embedding:
                embed = self.embedding(label, impl=impl)
                if self.normalize_embedding:
                    embed = self.norm(embed)
                for i in range(self.mapping_layers):
                    embed = getattr(self, f'mapping{i}')(embed, impl=impl)

        # Cast to `torch.float16` if needed.
        if fp16_res is not None and self.resolution >= fp16_res:
            image = image.to(torch.float16)

        x = self.input0(image, impl=impl)

        for res_log2 in range(self.final_res_log2, self.init_res_log2, -1):
            res = 2 ** res_log2
            # Cast to `torch.float16` if needed.
            if fp16_res is not None and res >= fp16_res:
                x = x.to(torch.float16)
            else:
                x = x.to(torch.float32)

            idx = self.final_res_log2 - res_log2  # Block index.

            if self.architecture == 'skip' and idx > 0:
                image = upfirdn2d.downsample2d(image, self.filter, impl=impl)
                # Cast to `torch.float16` if needed.
                if fp16_res is not None and res >= fp16_res:
                    image = image.to(torch.float16)
                else:
                    image = image.to(torch.float32)
                y = getattr(self, f'input{idx}')(image, impl=impl)
                x = x + y

            if self.architecture == 'resnet':
                residual = getattr(self, f'residual{idx}')(
                    x, runtime_gain=np.sqrt(0.5), impl=impl)
                x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
                x = getattr(self, f'layer{2 * idx + 1}')(
                    x, runtime_gain=np.sqrt(0.5), impl=impl)
                x = x + residual
            else:
                x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
                x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl)

        # Final output.
        idx += 1
        if fp16_res is not None:  # Always use FP32 for the last block.
            x = x.to(torch.float32)
        if self.architecture == 'skip':
            image = upfirdn2d.downsample2d(image, self.filter, impl=impl)
            if fp16_res is not None:  # Always use FP32 for the last block.
                image = image.to(torch.float32)
            y = getattr(self, f'input{idx}')(image, impl=impl)
            x = x + y
        x = self.mbstd(x)
        x = getattr(self, f'layer{2 * idx}')(x, impl=impl)
        x = getattr(self, f'layer{2 * idx + 1}')(x, impl=impl)
        x = self.output(x, impl=impl)

        if self.use_embedding:
            x = (x * embed).sum(dim=1, keepdim=True)
            x = x / np.sqrt(self.embedding_dim)
        elif self.label_dim > 0:
            x = (x * label).sum(dim=1, keepdim=True)

        results = {
            'score': x,
            'label': label
        }
        if self.use_embedding:
            results['embedding'] = embed
        return results
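Two things worth calling out in the forward pass above: with `label_dim > 0` this acts as a projection discriminator (the final score is the inner product between the backbone output and the label embedding, scaled by 1/sqrt(embedding_dim)), and `fp16_res` only switches the high-resolution blocks to half precision. A minimal usage sketch; per the docstring, the `ref` implementation mode falls back to native PyTorch ops instead of the custom CUDA kernels:

import torch

from models.stylegan2_discriminator import StyleGAN2Discriminator

D = StyleGAN2Discriminator(resolution=256, label_dim=10)
image = torch.randn(2, 3, 256, 256)
label = torch.eye(10)[torch.tensor([3, 7])]  # One-hot labels, shape [2, 10].

with torch.no_grad():
    out = D(image, label=label, impl='ref')
print(out['score'].shape)  # torch.Size([2, 1])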
class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer."""

    def __init__(self, dim, eps):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def extra_repr(self):
        return f'dim={self.dim}, epsilon={self.eps}'

    def forward(self, x):
        scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
        return x * scale
471 |
+
class MiniBatchSTDLayer(nn.Module):
|
472 |
+
"""Implements the minibatch standard deviation layer."""
|
473 |
+
|
474 |
+
def __init__(self, groups, new_channels, eps):
|
475 |
+
super().__init__()
|
476 |
+
self.groups = groups
|
477 |
+
self.new_channels = new_channels
|
478 |
+
self.eps = eps
|
479 |
+
|
480 |
+
def extra_repr(self):
|
481 |
+
return (f'groups={self.groups}, '
|
482 |
+
f'new_channels={self.new_channels}, '
|
483 |
+
f'epsilon={self.eps}')
|
484 |
+
|
485 |
+
def forward(self, x):
|
486 |
+
if self.groups <= 1 or self.new_channels < 1:
|
487 |
+
return x
|
488 |
+
|
489 |
+
dtype = x.dtype
|
490 |
+
|
491 |
+
N, C, H, W = x.shape
|
492 |
+
G = min(self.groups, N) # Number of groups.
|
493 |
+
nC = self.new_channels # Number of channel groups.
|
494 |
+
c = C // nC # Channels per channel group.
|
495 |
+
|
496 |
+
y = x.reshape(G, -1, nC, c, H, W)  # [G, n, nC, c, H, W]
|
497 |
+
y = y - y.mean(dim=0)  # [G, n, nC, c, H, W]
|
498 |
+
y = y.square().mean(dim=0)  # [n, nC, c, H, W]
|
499 |
+
y = (y + self.eps).sqrt()  # [n, nC, c, H, W]
|
500 |
+
y = y.mean(dim=(2, 3, 4))  # [n, nC]
|
501 |
+
y = y.reshape(-1, nC, 1, 1)  # [n, nC, 1, 1]
|
502 |
+
y = y.repeat(G, 1, H, W)  # [N, nC, H, W]
|
503 |
+
x = torch.cat((x, y), dim=1)  # [N, C + nC, H, W]
|
504 |
+
|
505 |
+
assert x.dtype == dtype
|
506 |
+
return x
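# Shape walkthrough (added annotation; values are illustrative). With
# groups=4 and new_channels=1, an [8, 16, 4, 4] input gains exactly one
# extra channel holding the group-wise std statistic:
#
#   layer = MiniBatchSTDLayer(groups=4, new_channels=1, eps=1e-8)
#   x = torch.randn(8, 16, 4, 4)
#   y = layer(x)
#   assert y.shape == (8, 17, 4, 4)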
|
507 |
+
|
508 |
+
|
509 |
+
class ConvLayer(nn.Module):
|
510 |
+
"""Implements the convolutional layer.
|
511 |
+
|
512 |
+
If downsampling is needed (i.e., `scale_factor = 2`), the feature map will
|
513 |
+
be filtered with `filter_kernel` first.
|
514 |
+
"""
|
515 |
+
|
516 |
+
def __init__(self,
|
517 |
+
in_channels,
|
518 |
+
out_channels,
|
519 |
+
kernel_size,
|
520 |
+
add_bias,
|
521 |
+
scale_factor,
|
522 |
+
filter_kernel,
|
523 |
+
use_wscale,
|
524 |
+
wscale_gain,
|
525 |
+
lr_mul,
|
526 |
+
activation_type,
|
527 |
+
conv_clamp):
|
528 |
+
"""Initializes with layer settings.
|
529 |
+
|
530 |
+
Args:
|
531 |
+
in_channels: Number of channels of the input tensor.
|
532 |
+
out_channels: Number of channels of the output tensor.
|
533 |
+
kernel_size: Size of the convolutional kernels.
|
534 |
+
add_bias: Whether to add bias onto the convolutional result.
|
535 |
+
scale_factor: Scale factor for downsampling. `1` means skip
|
536 |
+
downsampling.
|
537 |
+
filter_kernel: Kernel used for filtering.
|
538 |
+
use_wscale: Whether to use weight scaling.
|
539 |
+
wscale_gain: Gain factor for weight scaling.
|
540 |
+
lr_mul: Learning multiplier for both weight and bias.
|
541 |
+
activation_type: Type of activation.
|
542 |
+
conv_clamp: A threshold to clamp the output of convolution layers to
|
543 |
+
avoid overflow under FP16 training.
|
544 |
+
"""
|
545 |
+
super().__init__()
|
546 |
+
self.in_channels = in_channels
|
547 |
+
self.out_channels = out_channels
|
548 |
+
self.kernel_size = kernel_size
|
549 |
+
self.add_bias = add_bias
|
550 |
+
self.scale_factor = scale_factor
|
551 |
+
self.filter_kernel = filter_kernel
|
552 |
+
self.use_wscale = use_wscale
|
553 |
+
self.wscale_gain = wscale_gain
|
554 |
+
self.lr_mul = lr_mul
|
555 |
+
self.activation_type = activation_type
|
556 |
+
self.conv_clamp = conv_clamp
|
557 |
+
|
558 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
559 |
+
fan_in = kernel_size * kernel_size * in_channels
|
560 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
561 |
+
if use_wscale:
|
562 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
563 |
+
self.wscale = wscale * lr_mul
|
564 |
+
else:
|
565 |
+
self.weight = nn.Parameter(
|
566 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
567 |
+
self.wscale = lr_mul
|
568 |
+
|
569 |
+
if add_bias:
|
570 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
571 |
+
self.bscale = lr_mul
|
572 |
+
else:
|
573 |
+
self.bias = None
|
574 |
+
self.act_gain = bias_act.activation_funcs[activation_type].def_gain
|
575 |
+
|
576 |
+
if scale_factor > 1:
|
577 |
+
assert filter_kernel is not None
|
578 |
+
self.register_buffer(
|
579 |
+
'filter', upfirdn2d.setup_filter(filter_kernel))
|
580 |
+
fh, fw = self.filter.shape
|
581 |
+
self.filter_padding = (
|
582 |
+
kernel_size // 2 + (fw - scale_factor + 1) // 2,
|
583 |
+
kernel_size // 2 + (fw - scale_factor) // 2,
|
584 |
+
kernel_size // 2 + (fh - scale_factor + 1) // 2,
|
585 |
+
kernel_size // 2 + (fh - scale_factor) // 2)
|
586 |
+
|
587 |
+
def extra_repr(self):
|
588 |
+
return (f'in_ch={self.in_channels}, '
|
589 |
+
f'out_ch={self.out_channels}, '
|
590 |
+
f'ksize={self.kernel_size}, '
|
591 |
+
f'wscale_gain={self.wscale_gain:.3f}, '
|
592 |
+
f'bias={self.add_bias}, '
|
593 |
+
f'lr_mul={self.lr_mul:.3f}, '
|
594 |
+
f'downsample={self.scale_factor}, '
|
595 |
+
f'downsample_filter={self.filter_kernel}, '
|
596 |
+
f'act={self.activation_type}, '
|
597 |
+
f'clamp={self.conv_clamp}')
|
598 |
+
|
599 |
+
def forward(self, x, runtime_gain=1.0, impl='cuda'):
|
600 |
+
dtype = x.dtype
|
601 |
+
|
602 |
+
weight = self.weight
|
603 |
+
if self.wscale != 1.0:
|
604 |
+
weight = weight * self.wscale
|
605 |
+
bias = None
|
606 |
+
if self.bias is not None:
|
607 |
+
bias = self.bias.to(dtype)
|
608 |
+
if self.bscale != 1.0:
|
609 |
+
bias = bias * self.bscale
|
610 |
+
|
611 |
+
if self.scale_factor == 1: # Native convolution without downsampling.
|
612 |
+
padding = self.kernel_size // 2
|
613 |
+
x = conv2d_gradfix.conv2d(
|
614 |
+
x, weight.to(dtype), stride=1, padding=padding, impl=impl)
|
615 |
+
else: # Convolution with downsampling.
|
616 |
+
down = self.scale_factor
|
617 |
+
f = self.filter
|
618 |
+
padding = self.filter_padding
|
619 |
+
# When kernel size = 1, use filtering function for downsampling.
|
620 |
+
if self.kernel_size == 1:
|
621 |
+
x = upfirdn2d.upfirdn2d(
|
622 |
+
x, f, down=down, padding=padding, impl=impl)
|
623 |
+
x = conv2d_gradfix.conv2d(
|
624 |
+
x, weight.to(dtype), stride=1, padding=0, impl=impl)
|
625 |
+
# When kernel size != 1, use strided convolution for downsampling.
|
626 |
+
else:
|
627 |
+
x = upfirdn2d.upfirdn2d(
|
628 |
+
x, f, down=1, padding=padding, impl=impl)
|
629 |
+
x = conv2d_gradfix.conv2d(
|
630 |
+
x, weight.to(dtype), stride=down, padding=0, impl=impl)
|
631 |
+
|
632 |
+
act_gain = self.act_gain * runtime_gain
|
633 |
+
act_clamp = None
|
634 |
+
if self.conv_clamp is not None:
|
635 |
+
act_clamp = self.conv_clamp * runtime_gain
|
636 |
+
x = bias_act.bias_act(x, bias,
|
637 |
+
act=self.activation_type,
|
638 |
+
gain=act_gain,
|
639 |
+
clamp=act_clamp,
|
640 |
+
impl=impl)
|
641 |
+
|
642 |
+
assert x.dtype == dtype
|
643 |
+
return x
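# Usage sketch (added annotation; settings are illustrative): a 3x3
# convolution that halves the resolution with the binomial filter used
# elsewhere in this file. `impl='ref'` is assumed to select the native
# PyTorch fallback of the third-party ops, so no CUDA build is needed:
#
#   conv = ConvLayer(in_channels=64, out_channels=128, kernel_size=3,
#                    add_bias=True, scale_factor=2,
#                    filter_kernel=(1, 3, 3, 1), use_wscale=True,
#                    wscale_gain=1.0, lr_mul=1.0,
#                    activation_type='lrelu', conv_clamp=None)
#   y = conv(torch.randn(1, 64, 32, 32), impl='ref')
#   assert y.shape == (1, 128, 16, 16)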
|
644 |
+
|
645 |
+
|
646 |
+
class DenseLayer(nn.Module):
|
647 |
+
"""Implements the dense layer."""
|
648 |
+
|
649 |
+
def __init__(self,
|
650 |
+
in_channels,
|
651 |
+
out_channels,
|
652 |
+
add_bias,
|
653 |
+
init_bias,
|
654 |
+
use_wscale,
|
655 |
+
wscale_gain,
|
656 |
+
lr_mul,
|
657 |
+
activation_type):
|
658 |
+
"""Initializes with layer settings.
|
659 |
+
|
660 |
+
Args:
|
661 |
+
in_channels: Number of channels of the input tensor.
|
662 |
+
out_channels: Number of channels of the output tensor.
|
663 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
664 |
+
init_bias: The initial bias value before training.
|
665 |
+
use_wscale: Whether to use weight scaling.
|
666 |
+
wscale_gain: Gain factor for weight scaling.
|
667 |
+
lr_mul: Learning multiplier for both weight and bias.
|
668 |
+
activation_type: Type of activation.
|
669 |
+
"""
|
670 |
+
super().__init__()
|
671 |
+
self.in_channels = in_channels
|
672 |
+
self.out_channels = out_channels
|
673 |
+
self.add_bias = add_bias
|
674 |
+
self.init_bias = init_bias
|
675 |
+
self.use_wscale = use_wscale
|
676 |
+
self.wscale_gain = wscale_gain
|
677 |
+
self.lr_mul = lr_mul
|
678 |
+
self.activation_type = activation_type
|
679 |
+
|
680 |
+
weight_shape = (out_channels, in_channels)
|
681 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
682 |
+
if use_wscale:
|
683 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
684 |
+
self.wscale = wscale * lr_mul
|
685 |
+
else:
|
686 |
+
self.weight = nn.Parameter(
|
687 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
688 |
+
self.wscale = lr_mul
|
689 |
+
|
690 |
+
if add_bias:
|
691 |
+
init_bias = np.float32(init_bias) / lr_mul
|
692 |
+
self.bias = nn.Parameter(torch.full([out_channels], init_bias))
|
693 |
+
self.bscale = lr_mul
|
694 |
+
else:
|
695 |
+
self.bias = None
|
696 |
+
|
697 |
+
def extra_repr(self):
|
698 |
+
return (f'in_ch={self.in_channels}, '
|
699 |
+
f'out_ch={self.out_channels}, '
|
700 |
+
f'wscale_gain={self.wscale_gain:.3f}, '
|
701 |
+
f'bias={self.add_bias}, '
|
702 |
+
f'init_bias={self.init_bias}, '
|
703 |
+
f'lr_mul={self.lr_mul:.3f}, '
|
704 |
+
f'act={self.activation_type}')
|
705 |
+
|
706 |
+
def forward(self, x, impl='cuda'):
|
707 |
+
dtype = x.dtype
|
708 |
+
|
709 |
+
if x.ndim != 2:
|
710 |
+
x = x.flatten(start_dim=1)
|
711 |
+
|
712 |
+
weight = self.weight.to(dtype) * self.wscale
|
713 |
+
bias = None
|
714 |
+
if self.bias is not None:
|
715 |
+
bias = self.bias.to(dtype)
|
716 |
+
if self.bscale != 1.0:
|
717 |
+
bias = bias * self.bscale
|
718 |
+
|
719 |
+
# Fast pass for linear activation.
|
720 |
+
if self.activation_type == 'linear' and bias is not None:
|
721 |
+
x = torch.addmm(bias.unsqueeze(0), x, weight.t())
|
722 |
+
else:
|
723 |
+
x = x.matmul(weight.t())
|
724 |
+
x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl)
|
725 |
+
|
726 |
+
assert x.dtype == dtype
|
727 |
+
return x
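# Equalized-LR note (added annotation): with `use_wscale=True`, the weight
# is initialized as randn() / lr_mul and multiplied at runtime by
# wscale_gain / sqrt(fan_in) * lr_mul, so the effective weight std is
# always wscale_gain / sqrt(fan_in), independent of lr_mul. Sketch:
#
#   fc = DenseLayer(in_channels=512, out_channels=512, add_bias=True,
#                   init_bias=0.0, use_wscale=True, wscale_gain=1.0,
#                   lr_mul=0.01, activation_type='lrelu')
#   y = fc(torch.randn(2, 512), impl='ref')  # -> shape [2, 512]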
|
728 |
+
|
729 |
+
# pylint: enable=missing-function-docstring
|
models/stylegan2_generator.py
ADDED
@@ -0,0 +1,1394 @@
1 |
+
# python3.7
|
2 |
+
"""Contains the implementation of generator described in StyleGAN2.
|
3 |
+
|
4 |
+
Compared to that of StyleGAN, the generator in StyleGAN2 mainly introduces style
|
5 |
+
demodulation, adds skip connections, increases model size, and disables
|
6 |
+
progressive growth. This script ONLY supports config F in the original paper.
|
7 |
+
|
8 |
+
Paper: https://arxiv.org/pdf/1912.04958.pdf
|
9 |
+
|
10 |
+
Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
|
11 |
+
"""
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
|
18 |
+
from third_party.stylegan2_official_ops import fma
|
19 |
+
from third_party.stylegan2_official_ops import bias_act
|
20 |
+
from third_party.stylegan2_official_ops import upfirdn2d
|
21 |
+
from third_party.stylegan2_official_ops import conv2d_gradfix
|
22 |
+
from .utils.ops import all_gather
|
23 |
+
|
24 |
+
__all__ = ['StyleGAN2Generator']
|
25 |
+
|
26 |
+
# Resolutions allowed.
|
27 |
+
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
|
28 |
+
|
29 |
+
# Architectures allowed.
|
30 |
+
_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
|
31 |
+
|
32 |
+
# pylint: disable=missing-function-docstring
|
33 |
+
|
34 |
+
class StyleGAN2Generator(nn.Module):
|
35 |
+
"""Defines the generator network in StyleGAN2.
|
36 |
+
|
37 |
+
NOTE: The synthesized images use `RGB` channel order and pixel range
|
38 |
+
[-1, 1].
|
39 |
+
|
40 |
+
Settings for the mapping network:
|
41 |
+
|
42 |
+
(1) z_dim: Dimension of the input latent space, Z. (default: 512)
|
43 |
+
(2) w_dim: Dimension of the output latent space, W. (default: 512)
|
44 |
+
(3) repeat_w: Repeat w-code for different layers. (default: True)
|
45 |
+
(4) normalize_z: Whether to normalize the z-code. (default: True)
|
46 |
+
(5) mapping_layers: Number of layers of the mapping network. (default: 8)
|
47 |
+
(6) mapping_fmaps: Number of hidden channels of the mapping network.
|
48 |
+
(default: 512)
|
49 |
+
(7) mapping_use_wscale: Whether to use weight scaling for the mapping
|
50 |
+
network. (default: True)
|
51 |
+
(8) mapping_wscale_gain: The factor to control weight scaling for the
|
52 |
+
mapping network (default: 1.0)
|
53 |
+
(9) mapping_lr_mul: Learning rate multiplier for the mapping network.
|
54 |
+
(default: 0.01)
|
55 |
+
|
56 |
+
Settings for conditional generation:
|
57 |
+
|
58 |
+
(1) label_dim: Dimension of the additional label for conditional generation.
|
59 |
+
In the one-hot conditioning case, it equals the number of classes. If
|
60 |
+
set to 0, conditioning training will be disabled. (default: 0)
|
61 |
+
(2) embedding_dim: Dimension of the embedding space, if needed.
|
62 |
+
(default: 512)
|
63 |
+
(3) embedding_bias: Whether to add bias to embedding learning.
|
64 |
+
(default: True)
|
65 |
+
(4) embedding_use_wscale: Whether to use weight scaling for embedding
|
66 |
+
learning. (default: True)
|
67 |
+
(5) embedding_wscale_gain: The factor to control weight scaling for
|
68 |
+
embedding. (default: 1.0)
|
69 |
+
(6) embedding_lr_mul: Learning rate multiplier for the embedding learning.
|
70 |
+
(default: 1.0)
|
71 |
+
(7) normalize_embedding: Whether to normalize the embedding. (default: True)
|
72 |
+
(8) normalize_embedding_latent: Whether to normalize the embedding together
|
73 |
+
with the latent. (default: False)
|
74 |
+
|
75 |
+
Settings for the synthesis network:
|
76 |
+
|
77 |
+
(1) resolution: The resolution of the output image. (default: -1)
|
78 |
+
(2) init_res: The initial resolution to start with convolution. (default: 4)
|
79 |
+
(3) image_channels: Number of channels of the output image. (default: 3)
|
80 |
+
(4) final_tanh: Whether to use `tanh` to control the final pixel range.
|
81 |
+
(default: False)
|
82 |
+
(5) const_input: Whether to use a constant in the first convolutional layer.
|
83 |
+
(default: True)
|
84 |
+
(6) architecture: Type of architecture. Support `origin`, `skip`, and
|
85 |
+
`resnet`. (default: `skip`)
|
86 |
+
(7) demodulate: Whether to perform style demodulation. (default: True)
|
87 |
+
(8) use_wscale: Whether to use weight scaling. (default: True)
|
88 |
+
(9) wscale_gain: The factor to control weight scaling. (default: 1.0)
|
89 |
+
(10) lr_mul: Learning rate multiplier for the synthesis network.
|
90 |
+
(default: 1.0)
|
91 |
+
(11) noise_type: Type of noise added to the convolutional results at each
|
92 |
+
layer. (default: `spatial`)
|
93 |
+
(12) fmaps_base: Factor to control number of feature maps for each layer.
|
94 |
+
(default: 32 << 10)
|
95 |
+
(13) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
|
96 |
+
(14) filter_kernel: Kernel used for filtering (e.g., downsampling).
|
97 |
+
(default: (1, 3, 3, 1))
|
98 |
+
(15) conv_clamp: A threshold to clamp the output of convolution layers to
|
99 |
+
avoid overflow under FP16 training. (default: None)
|
100 |
+
(16) eps: A small value to avoid divide overflow. (default: 1e-8)
|
101 |
+
|
102 |
+
Runtime settings:
|
103 |
+
|
104 |
+
(1) w_moving_decay: Decay factor for updating `w_avg`, which is used for
|
105 |
+
training only. Set `None` to disable. (default: None)
|
106 |
+
(2) sync_w_avg: Whether to synchronize the stats of `w_avg` across replicas. If set
|
107 |
+
as `True`, the stats will be more accurate, yet the speed may be a little
|
108 |
+
bit slower. (default: False)
|
109 |
+
(3) style_mixing_prob: Probability to perform style mixing as a training
|
110 |
+
regularization. Set `None` to disable. (default: None)
|
111 |
+
(4) trunc_psi: Truncation psi, set `None` to disable. (default: None)
|
112 |
+
(5) trunc_layers: Number of layers to perform truncation. (default: None)
|
113 |
+
(6) noise_mode: Mode of the layer-wise noise. Support `none`, `random`,
|
114 |
+
`const`. (default: `const`)
|
115 |
+
(7) fused_modulate: Whether to fuse `style_modulate` and `conv2d` together.
|
116 |
+
(default: False)
|
117 |
+
(8) fp16_res: Layers at resolution higher than (or equal to) this field will
|
118 |
+
use `float16` precision for computation. This is merely used for
|
119 |
+
acceleration. If set as `None`, all layers will use `float32` by
|
120 |
+
default. (default: None)
|
121 |
+
(9) impl: Implementation mode of some particular ops, e.g., `filtering`,
|
122 |
+
`bias_act`, etc. `cuda` means using the official CUDA implementation
|
123 |
+
from StyleGAN2, while `ref` means using the native PyTorch ops.
|
124 |
+
(default: `cuda`)
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(self,
|
128 |
+
# Settings for mapping network.
|
129 |
+
z_dim=512,
|
130 |
+
w_dim=512,
|
131 |
+
repeat_w=True,
|
132 |
+
normalize_z=True,
|
133 |
+
mapping_layers=8,
|
134 |
+
mapping_fmaps=512,
|
135 |
+
mapping_use_wscale=True,
|
136 |
+
mapping_wscale_gain=1.0,
|
137 |
+
mapping_lr_mul=0.01,
|
138 |
+
# Settings for conditional generation.
|
139 |
+
label_dim=0,
|
140 |
+
embedding_dim=512,
|
141 |
+
embedding_bias=True,
|
142 |
+
embedding_use_wscale=True,
|
143 |
+
embedding_wscale_gain=1.0,
|
144 |
+
embedding_lr_mul=1.0,
|
145 |
+
normalize_embedding=True,
|
146 |
+
normalize_embedding_latent=False,
|
147 |
+
# Settings for synthesis network.
|
148 |
+
resolution=-1,
|
149 |
+
init_res=4,
|
150 |
+
image_channels=3,
|
151 |
+
final_tanh=False,
|
152 |
+
const_input=True,
|
153 |
+
architecture='skip',
|
154 |
+
demodulate=True,
|
155 |
+
use_wscale=True,
|
156 |
+
wscale_gain=1.0,
|
157 |
+
lr_mul=1.0,
|
158 |
+
noise_type='spatial',
|
159 |
+
fmaps_base=32 << 10,
|
160 |
+
fmaps_max=512,
|
161 |
+
filter_kernel=(1, 3, 3, 1),
|
162 |
+
conv_clamp=None,
|
163 |
+
eps=1e-8):
|
164 |
+
"""Initializes with basic settings.
|
165 |
+
|
166 |
+
Raises:
|
167 |
+
ValueError: If the `resolution` is not supported, or `architecture`
|
168 |
+
is not supported.
|
169 |
+
"""
|
170 |
+
super().__init__()
|
171 |
+
|
172 |
+
if resolution not in _RESOLUTIONS_ALLOWED:
|
173 |
+
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
|
174 |
+
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
|
175 |
+
architecture = architecture.lower()
|
176 |
+
if architecture not in _ARCHITECTURES_ALLOWED:
|
177 |
+
raise ValueError(f'Invalid architecture: `{architecture}`!\n'
|
178 |
+
f'Architectures allowed: '
|
179 |
+
f'{_ARCHITECTURES_ALLOWED}.')
|
180 |
+
|
181 |
+
self.z_dim = z_dim
|
182 |
+
self.w_dim = w_dim
|
183 |
+
self.repeat_w = repeat_w
|
184 |
+
self.normalize_z = normalize_z
|
185 |
+
self.mapping_layers = mapping_layers
|
186 |
+
self.mapping_fmaps = mapping_fmaps
|
187 |
+
self.mapping_use_wscale = mapping_use_wscale
|
188 |
+
self.mapping_wscale_gain = mapping_wscale_gain
|
189 |
+
self.mapping_lr_mul = mapping_lr_mul
|
190 |
+
|
191 |
+
self.label_dim = label_dim
|
192 |
+
self.embedding_dim = embedding_dim
|
193 |
+
self.embedding_bias = embedding_bias
|
194 |
+
self.embedding_use_wscale = embedding_use_wscale
|
195 |
+
self.embedding_wscale_gain = embedding_wscale_gain
|
196 |
+
self.embedding_lr_mul = embedding_lr_mul
|
197 |
+
self.normalize_embedding = normalize_embedding
|
198 |
+
self.normalize_embedding_latent = normalize_embedding_latent
|
199 |
+
|
200 |
+
self.resolution = resolution
|
201 |
+
self.init_res = init_res
|
202 |
+
self.image_channels = image_channels
|
203 |
+
self.final_tanh = final_tanh
|
204 |
+
self.const_input = const_input
|
205 |
+
self.architecture = architecture
|
206 |
+
self.demodulate = demodulate
|
207 |
+
self.use_wscale = use_wscale
|
208 |
+
self.wscale_gain = wscale_gain
|
209 |
+
self.lr_mul = lr_mul
|
210 |
+
self.noise_type = noise_type.lower()
|
211 |
+
self.fmaps_base = fmaps_base
|
212 |
+
self.fmaps_max = fmaps_max
|
213 |
+
self.filter_kernel = filter_kernel
|
214 |
+
self.conv_clamp = conv_clamp
|
215 |
+
self.eps = eps
|
216 |
+
|
217 |
+
# Dimension of latent space, which is convenient for sampling.
|
218 |
+
self.latent_dim = (z_dim,)
|
219 |
+
|
220 |
+
# Number of synthesis (convolutional) layers.
|
221 |
+
self.num_layers = int(np.log2(resolution // init_res * 2)) * 2
|
222 |
+
|
223 |
+
self.mapping = MappingNetwork(
|
224 |
+
input_dim=z_dim,
|
225 |
+
output_dim=w_dim,
|
226 |
+
num_outputs=self.num_layers,
|
227 |
+
repeat_output=repeat_w,
|
228 |
+
normalize_input=normalize_z,
|
229 |
+
num_layers=mapping_layers,
|
230 |
+
hidden_dim=mapping_fmaps,
|
231 |
+
use_wscale=mapping_use_wscale,
|
232 |
+
wscale_gain=mapping_wscale_gain,
|
233 |
+
lr_mul=mapping_lr_mul,
|
234 |
+
label_dim=label_dim,
|
235 |
+
embedding_dim=embedding_dim,
|
236 |
+
embedding_bias=embedding_bias,
|
237 |
+
embedding_use_wscale=embedding_use_wscale,
|
238 |
+
embedding_wscale_gain=embedding_wscale_gain,
|
239 |
+
embedding_lr_mul=embedding_lr_mul,
|
240 |
+
normalize_embedding=normalize_embedding,
|
241 |
+
normalize_embedding_latent=normalize_embedding_latent,
|
242 |
+
eps=eps)
|
243 |
+
|
244 |
+
# This is used for truncation trick.
|
245 |
+
if self.repeat_w:
|
246 |
+
self.register_buffer('w_avg', torch.zeros(w_dim))
|
247 |
+
else:
|
248 |
+
self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim))
|
249 |
+
|
250 |
+
self.synthesis = SynthesisNetwork(resolution=resolution,
|
251 |
+
init_res=init_res,
|
252 |
+
w_dim=w_dim,
|
253 |
+
image_channels=image_channels,
|
254 |
+
final_tanh=final_tanh,
|
255 |
+
const_input=const_input,
|
256 |
+
architecture=architecture,
|
257 |
+
demodulate=demodulate,
|
258 |
+
use_wscale=use_wscale,
|
259 |
+
wscale_gain=wscale_gain,
|
260 |
+
lr_mul=lr_mul,
|
261 |
+
noise_type=noise_type,
|
262 |
+
fmaps_base=fmaps_base,
|
263 |
+
filter_kernel=filter_kernel,
|
264 |
+
fmaps_max=fmaps_max,
|
265 |
+
conv_clamp=conv_clamp,
|
266 |
+
eps=eps)
|
267 |
+
|
268 |
+
self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
|
269 |
+
for key, val in self.mapping.pth_to_tf_var_mapping.items():
|
270 |
+
self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
|
271 |
+
for key, val in self.synthesis.pth_to_tf_var_mapping.items():
|
272 |
+
self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val
|
273 |
+
|
274 |
+
def set_space_of_latent(self, space_of_latent):
|
275 |
+
"""Sets the space to which the latent code belong.
|
276 |
+
|
277 |
+
See `SynthesisNetwork` for more details.
|
278 |
+
"""
|
279 |
+
self.synthesis.set_space_of_latent(space_of_latent)
|
280 |
+
|
281 |
+
def forward(self,
|
282 |
+
z,
|
283 |
+
label=None,
|
284 |
+
w_moving_decay=None,
|
285 |
+
sync_w_avg=False,
|
286 |
+
style_mixing_prob=None,
|
287 |
+
trunc_psi=None,
|
288 |
+
trunc_layers=None,
|
289 |
+
noise_mode='const',
|
290 |
+
fused_modulate=False,
|
291 |
+
fp16_res=None,
|
292 |
+
impl='cuda'):
|
293 |
+
"""Connects mapping network and synthesis network.
|
294 |
+
|
295 |
+
This forward function will also update the average `w_code`, perform
|
296 |
+
style mixing as a training regularizer, and apply the truncation trick, which
|
297 |
+
is specially designed for inference.
|
298 |
+
|
299 |
+
Concretely, the truncation trick acts as follows:
|
300 |
+
|
301 |
+
For layers in range [0, trunc_layers), the truncated w-code is
|
302 |
+
computed as
|
303 |
+
|
304 |
+
w_new = w_avg + (w - w_avg) * trunc_psi
|
305 |
+
|
306 |
+
To disable truncation, please set
|
307 |
+
|
308 |
+
(1) trunc_psi = 1.0 (None) OR
|
309 |
+
(2) trunc_layers = 0 (None)
|
310 |
+
"""
|
311 |
+
|
312 |
+
mapping_results = self.mapping(z, label, impl=impl)
|
313 |
+
|
314 |
+
w = mapping_results['w']
|
315 |
+
if self.training and w_moving_decay is not None:
|
316 |
+
if sync_w_avg:
|
317 |
+
batch_w_avg = all_gather(w.detach()).mean(dim=0)
|
318 |
+
else:
|
319 |
+
batch_w_avg = w.detach().mean(dim=0)
|
320 |
+
self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay))
|
321 |
+
|
322 |
+
wp = mapping_results.pop('wp')
|
323 |
+
if self.training and style_mixing_prob is not None:
|
324 |
+
if np.random.uniform() < style_mixing_prob:
|
325 |
+
new_z = torch.randn_like(z)
|
326 |
+
new_wp = self.mapping(new_z, label, impl=impl)['wp']
|
327 |
+
mixing_cutoff = np.random.randint(1, self.num_layers)
|
328 |
+
wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:]
|
329 |
+
|
330 |
+
if not self.training:
|
331 |
+
trunc_psi = 1.0 if trunc_psi is None else trunc_psi
|
332 |
+
trunc_layers = 0 if trunc_layers is None else trunc_layers
|
333 |
+
if trunc_psi < 1.0 and trunc_layers > 0:
|
334 |
+
w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers]
|
335 |
+
wp[:, :trunc_layers] = w_avg.lerp(
|
336 |
+
wp[:, :trunc_layers], trunc_psi)
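# Added note: `a.lerp(b, t)` returns a + (b - a) * t, so the assignment
# above computes w_avg + (wp - w_avg) * trunc_psi, which is exactly the
# truncation formula given in the docstring.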
|
337 |
+
|
338 |
+
synthesis_results = self.synthesis(wp,
|
339 |
+
noise_mode=noise_mode,
|
340 |
+
fused_modulate=fused_modulate,
|
341 |
+
impl=impl,
|
342 |
+
fp16_res=fp16_res)
|
343 |
+
|
344 |
+
return {**mapping_results, **synthesis_results}
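# End-to-end sampling sketch (added annotation; weights are untrained and
# the settings are illustrative). Truncation only takes effect in eval
# mode, as implemented above:
#
#   G = StyleGAN2Generator(resolution=256)
#   G.eval()
#   z = torch.randn(4, G.z_dim)
#   out = G(z, trunc_psi=0.7, trunc_layers=8, impl='ref')
#   image = out['image']  # [4, 3, 256, 256], RGB, pixel range ~[-1, 1]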
|
345 |
+
|
346 |
+
|
347 |
+
class MappingNetwork(nn.Module):
|
348 |
+
"""Implements the latent space mapping network.
|
349 |
+
|
350 |
+
Basically, this network executes several dense layers in sequence, and the
|
351 |
+
label embedding if needed.
|
352 |
+
"""
|
353 |
+
|
354 |
+
def __init__(self,
|
355 |
+
input_dim,
|
356 |
+
output_dim,
|
357 |
+
num_outputs,
|
358 |
+
repeat_output,
|
359 |
+
normalize_input,
|
360 |
+
num_layers,
|
361 |
+
hidden_dim,
|
362 |
+
use_wscale,
|
363 |
+
wscale_gain,
|
364 |
+
lr_mul,
|
365 |
+
label_dim,
|
366 |
+
embedding_dim,
|
367 |
+
embedding_bias,
|
368 |
+
embedding_use_wscale,
|
369 |
+
embedding_wscale_gain,
|
370 |
+
embedding_lr_mul,
|
371 |
+
normalize_embedding,
|
372 |
+
normalize_embedding_latent,
|
373 |
+
eps):
|
374 |
+
super().__init__()
|
375 |
+
|
376 |
+
self.input_dim = input_dim
|
377 |
+
self.output_dim = output_dim
|
378 |
+
self.num_outputs = num_outputs
|
379 |
+
self.repeat_output = repeat_output
|
380 |
+
self.normalize_input = normalize_input
|
381 |
+
self.num_layers = num_layers
|
382 |
+
self.hidden_dim = hidden_dim
|
383 |
+
self.use_wscale = use_wscale
|
384 |
+
self.wscale_gain = wscale_gain
|
385 |
+
self.lr_mul = lr_mul
|
386 |
+
self.label_dim = label_dim
|
387 |
+
self.embedding_dim = embedding_dim
|
388 |
+
self.embedding_bias = embedding_bias
|
389 |
+
self.embedding_use_wscale = embedding_use_wscale
|
390 |
+
self.embedding_wscale_gain = embedding_wscale_gain
|
391 |
+
self.embedding_lr_mul = embedding_lr_mul
|
392 |
+
self.normalize_embedding = normalize_embedding
|
393 |
+
self.normalize_embedding_latent = normalize_embedding_latent
|
394 |
+
self.eps = eps
|
395 |
+
|
396 |
+
self.pth_to_tf_var_mapping = {}
|
397 |
+
|
398 |
+
self.norm = PixelNormLayer(dim=1, eps=eps)
|
399 |
+
|
400 |
+
if self.label_dim > 0:
|
401 |
+
input_dim = input_dim + embedding_dim
|
402 |
+
self.embedding = DenseLayer(in_channels=label_dim,
|
403 |
+
out_channels=embedding_dim,
|
404 |
+
add_bias=embedding_bias,
|
405 |
+
init_bias=0.0,
|
406 |
+
use_wscale=embedding_use_wscale,
|
407 |
+
wscale_gain=embedding_wscale_gain,
|
408 |
+
lr_mul=embedding_lr_mul,
|
409 |
+
activation_type='linear')
|
410 |
+
self.pth_to_tf_var_mapping['embedding.weight'] = 'LabelEmbed/weight'
|
411 |
+
if self.embedding_bias:
|
412 |
+
self.pth_to_tf_var_mapping['embedding.bias'] = 'LabelEmbed/bias'
|
413 |
+
|
414 |
+
if num_outputs is not None and not repeat_output:
|
415 |
+
output_dim = output_dim * num_outputs
|
416 |
+
for i in range(num_layers):
|
417 |
+
in_channels = (input_dim if i == 0 else hidden_dim)
|
418 |
+
out_channels = (output_dim if i == (num_layers - 1) else hidden_dim)
|
419 |
+
self.add_module(f'dense{i}',
|
420 |
+
DenseLayer(in_channels=in_channels,
|
421 |
+
out_channels=out_channels,
|
422 |
+
add_bias=True,
|
423 |
+
init_bias=0.0,
|
424 |
+
use_wscale=use_wscale,
|
425 |
+
wscale_gain=wscale_gain,
|
426 |
+
lr_mul=lr_mul,
|
427 |
+
activation_type='lrelu'))
|
428 |
+
self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
|
429 |
+
self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
|
430 |
+
|
431 |
+
def forward(self, z, label=None, impl='cuda'):
|
432 |
+
if z.ndim != 2 or z.shape[1] != self.input_dim:
|
433 |
+
raise ValueError(f'Input latent code should be with shape '
|
434 |
+
f'[batch_size, input_dim], where '
|
435 |
+
f'`input_dim` equals {self.input_dim}!\n'
|
436 |
+
f'But `{z.shape}` is received!')
|
437 |
+
if self.normalize_input:
|
438 |
+
z = self.norm(z)
|
439 |
+
|
440 |
+
if self.label_dim > 0:
|
441 |
+
if label is None:
|
442 |
+
raise ValueError(f'Model requires an additional label '
|
443 |
+
f'(with dimension {self.label_dim}) as input, '
|
444 |
+
f'but no label is received!')
|
445 |
+
if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim):
|
446 |
+
raise ValueError(f'Input label should be with shape '
|
447 |
+
f'[batch_size, label_dim], where '
|
448 |
+
f'`batch_size` equals that of '
|
449 |
+
f'latent codes ({z.shape[0]}) and '
|
450 |
+
f'`label_dim` equals {self.label_dim}!\n'
|
451 |
+
f'But `{label.shape}` is received!')
|
452 |
+
label = label.to(dtype=torch.float32)
|
453 |
+
embedding = self.embedding(label, impl=impl)
|
454 |
+
if self.normalize_embedding:
|
455 |
+
embedding = self.norm(embedding)
|
456 |
+
w = torch.cat((z, embedding), dim=1)
|
457 |
+
else:
|
458 |
+
w = z
|
459 |
+
|
460 |
+
if self.label_dim > 0 and self.normalize_embedding_latent:
|
461 |
+
w = self.norm(w)
|
462 |
+
|
463 |
+
for i in range(self.num_layers):
|
464 |
+
w = getattr(self, f'dense{i}')(w, impl=impl)
|
465 |
+
|
466 |
+
wp = None
|
467 |
+
if self.num_outputs is not None:
|
468 |
+
if self.repeat_output:
|
469 |
+
wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1))
|
470 |
+
else:
|
471 |
+
wp = w.reshape(-1, self.num_outputs, self.output_dim)
|
472 |
+
|
473 |
+
results = {
|
474 |
+
'z': z,
|
475 |
+
'label': label,
|
476 |
+
'w': w,
|
477 |
+
'wp': wp,
|
478 |
+
}
|
479 |
+
if self.label_dim > 0:
|
480 |
+
results['embedding'] = embedding
|
481 |
+
return results
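# Usage sketch (added annotation; unconditional 256x256 settings, which
# give num_outputs = 14 synthesis layers):
#
#   mapping = MappingNetwork(
#       input_dim=512, output_dim=512, num_outputs=14, repeat_output=True,
#       normalize_input=True, num_layers=8, hidden_dim=512,
#       use_wscale=True, wscale_gain=1.0, lr_mul=0.01, label_dim=0,
#       embedding_dim=512, embedding_bias=True, embedding_use_wscale=True,
#       embedding_wscale_gain=1.0, embedding_lr_mul=1.0,
#       normalize_embedding=True, normalize_embedding_latent=False,
#       eps=1e-8)
#   res = mapping(torch.randn(2, 512), impl='ref')
#   assert res['w'].shape == (2, 512)
#   assert res['wp'].shape == (2, 14, 512)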
|
482 |
+
|
483 |
+
|
484 |
+
class SynthesisNetwork(nn.Module):
|
485 |
+
"""Implements the image synthesis network.
|
486 |
+
|
487 |
+
Basically, this network executes several convolutional layers in sequence.
|
488 |
+
"""
|
489 |
+
|
490 |
+
def __init__(self,
|
491 |
+
resolution,
|
492 |
+
init_res,
|
493 |
+
w_dim,
|
494 |
+
image_channels,
|
495 |
+
final_tanh,
|
496 |
+
const_input,
|
497 |
+
architecture,
|
498 |
+
demodulate,
|
499 |
+
use_wscale,
|
500 |
+
wscale_gain,
|
501 |
+
lr_mul,
|
502 |
+
noise_type,
|
503 |
+
fmaps_base,
|
504 |
+
fmaps_max,
|
505 |
+
filter_kernel,
|
506 |
+
conv_clamp,
|
507 |
+
eps):
|
508 |
+
super().__init__()
|
509 |
+
|
510 |
+
self.init_res = init_res
|
511 |
+
self.init_res_log2 = int(np.log2(init_res))
|
512 |
+
self.resolution = resolution
|
513 |
+
self.final_res_log2 = int(np.log2(resolution))
|
514 |
+
self.w_dim = w_dim
|
515 |
+
self.image_channels = image_channels
|
516 |
+
self.final_tanh = final_tanh
|
517 |
+
self.const_input = const_input
|
518 |
+
self.architecture = architecture.lower()
|
519 |
+
self.demodulate = demodulate
|
520 |
+
self.use_wscale = use_wscale
|
521 |
+
self.wscale_gain = wscale_gain
|
522 |
+
self.lr_mul = lr_mul
|
523 |
+
self.noise_type = noise_type.lower()
|
524 |
+
self.fmaps_base = fmaps_base
|
525 |
+
self.fmaps_max = fmaps_max
|
526 |
+
self.filter_kernel = filter_kernel
|
527 |
+
self.conv_clamp = conv_clamp
|
528 |
+
self.eps = eps
|
529 |
+
|
530 |
+
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
|
531 |
+
|
532 |
+
self.pth_to_tf_var_mapping = {}
|
533 |
+
|
534 |
+
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
|
535 |
+
res = 2 ** res_log2
|
536 |
+
in_channels = self.get_nf(res // 2)
|
537 |
+
out_channels = self.get_nf(res)
|
538 |
+
block_idx = res_log2 - self.init_res_log2
|
539 |
+
|
540 |
+
# Early layer.
|
541 |
+
if res == init_res:
|
542 |
+
if self.const_input:
|
543 |
+
self.add_module('early_layer',
|
544 |
+
InputLayer(init_res=res,
|
545 |
+
channels=out_channels))
|
546 |
+
self.pth_to_tf_var_mapping['early_layer.const'] = (
|
547 |
+
f'{res}x{res}/Const/const')
|
548 |
+
else:
|
549 |
+
channels = out_channels * res * res
|
550 |
+
self.add_module('early_layer',
|
551 |
+
DenseLayer(in_channels=w_dim,
|
552 |
+
out_channels=channels,
|
553 |
+
add_bias=True,
|
554 |
+
init_bias=0.0,
|
555 |
+
use_wscale=use_wscale,
|
556 |
+
wscale_gain=wscale_gain,
|
557 |
+
lr_mul=lr_mul,
|
558 |
+
activation_type='lrelu'))
|
559 |
+
self.pth_to_tf_var_mapping['early_layer.weight'] = (
|
560 |
+
f'{res}x{res}/Dense/weight')
|
561 |
+
self.pth_to_tf_var_mapping['early_layer.bias'] = (
|
562 |
+
f'{res}x{res}/Dense/bias')
|
563 |
+
else:
|
564 |
+
# Residual branch (kernel 1x1) with upsampling, without bias,
|
565 |
+
# with linear activation.
|
566 |
+
if self.architecture == 'resnet':
|
567 |
+
layer_name = f'residual{block_idx}'
|
568 |
+
self.add_module(layer_name,
|
569 |
+
ConvLayer(in_channels=in_channels,
|
570 |
+
out_channels=out_channels,
|
571 |
+
kernel_size=1,
|
572 |
+
add_bias=False,
|
573 |
+
scale_factor=2,
|
574 |
+
filter_kernel=filter_kernel,
|
575 |
+
use_wscale=use_wscale,
|
576 |
+
wscale_gain=wscale_gain,
|
577 |
+
lr_mul=lr_mul,
|
578 |
+
activation_type='linear',
|
579 |
+
conv_clamp=None))
|
580 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
581 |
+
f'{res}x{res}/Skip/weight')
|
582 |
+
|
583 |
+
# First layer (kernel 3x3) with upsampling.
|
584 |
+
layer_name = f'layer{2 * block_idx - 1}'
|
585 |
+
self.add_module(layer_name,
|
586 |
+
ModulateConvLayer(in_channels=in_channels,
|
587 |
+
out_channels=out_channels,
|
588 |
+
resolution=res,
|
589 |
+
w_dim=w_dim,
|
590 |
+
kernel_size=3,
|
591 |
+
add_bias=True,
|
592 |
+
scale_factor=2,
|
593 |
+
filter_kernel=filter_kernel,
|
594 |
+
demodulate=demodulate,
|
595 |
+
use_wscale=use_wscale,
|
596 |
+
wscale_gain=wscale_gain,
|
597 |
+
lr_mul=lr_mul,
|
598 |
+
noise_type=noise_type,
|
599 |
+
activation_type='lrelu',
|
600 |
+
conv_clamp=conv_clamp,
|
601 |
+
eps=eps))
|
602 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
603 |
+
f'{res}x{res}/Conv0_up/weight')
|
604 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
605 |
+
f'{res}x{res}/Conv0_up/bias')
|
606 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
607 |
+
f'{res}x{res}/Conv0_up/mod_weight')
|
608 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
609 |
+
f'{res}x{res}/Conv0_up/mod_bias')
|
610 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
|
611 |
+
f'{res}x{res}/Conv0_up/noise_strength')
|
612 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
|
613 |
+
f'noise{2 * block_idx - 1}')
|
614 |
+
|
615 |
+
# Second layer (kernel 3x3) without upsampling.
|
616 |
+
layer_name = f'layer{2 * block_idx}'
|
617 |
+
self.add_module(layer_name,
|
618 |
+
ModulateConvLayer(in_channels=out_channels,
|
619 |
+
out_channels=out_channels,
|
620 |
+
resolution=res,
|
621 |
+
w_dim=w_dim,
|
622 |
+
kernel_size=3,
|
623 |
+
add_bias=True,
|
624 |
+
scale_factor=1,
|
625 |
+
filter_kernel=None,
|
626 |
+
demodulate=demodulate,
|
627 |
+
use_wscale=use_wscale,
|
628 |
+
wscale_gain=wscale_gain,
|
629 |
+
lr_mul=lr_mul,
|
630 |
+
noise_type=noise_type,
|
631 |
+
activation_type='lrelu',
|
632 |
+
conv_clamp=conv_clamp,
|
633 |
+
eps=eps))
|
634 |
+
tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
|
635 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
636 |
+
f'{res}x{res}/{tf_layer_name}/weight')
|
637 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
638 |
+
f'{res}x{res}/{tf_layer_name}/bias')
|
639 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
640 |
+
f'{res}x{res}/{tf_layer_name}/mod_weight')
|
641 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
642 |
+
f'{res}x{res}/{tf_layer_name}/mod_bias')
|
643 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
|
644 |
+
f'{res}x{res}/{tf_layer_name}/noise_strength')
|
645 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
|
646 |
+
f'noise{2 * block_idx}')
|
647 |
+
|
648 |
+
# Output convolution layer for each resolution (if needed).
|
649 |
+
if res_log2 == self.final_res_log2 or self.architecture == 'skip':
|
650 |
+
layer_name = f'output{block_idx}'
|
651 |
+
self.add_module(layer_name,
|
652 |
+
ModulateConvLayer(in_channels=out_channels,
|
653 |
+
out_channels=image_channels,
|
654 |
+
resolution=res,
|
655 |
+
w_dim=w_dim,
|
656 |
+
kernel_size=1,
|
657 |
+
add_bias=True,
|
658 |
+
scale_factor=1,
|
659 |
+
filter_kernel=None,
|
660 |
+
demodulate=False,
|
661 |
+
use_wscale=use_wscale,
|
662 |
+
wscale_gain=wscale_gain,
|
663 |
+
lr_mul=lr_mul,
|
664 |
+
noise_type='none',
|
665 |
+
activation_type='linear',
|
666 |
+
conv_clamp=conv_clamp,
|
667 |
+
eps=eps))
|
668 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
669 |
+
f'{res}x{res}/ToRGB/weight')
|
670 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
671 |
+
f'{res}x{res}/ToRGB/bias')
|
672 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
673 |
+
f'{res}x{res}/ToRGB/mod_weight')
|
674 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
675 |
+
f'{res}x{res}/ToRGB/mod_bias')
|
676 |
+
|
677 |
+
# Used to upsample the per-resolution output images before summation.
|
678 |
+
if self.architecture == 'skip':
|
679 |
+
self.register_buffer(
|
680 |
+
'filter', upfirdn2d.setup_filter(filter_kernel))
|
681 |
+
|
682 |
+
def get_nf(self, res):
|
683 |
+
"""Gets number of feature maps according to the given resolution."""
|
684 |
+
return min(self.fmaps_base // res, self.fmaps_max)
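# Worked examples (added annotation) with the defaults
# fmaps_base = 32 << 10 = 32768 and fmaps_max = 512:
#   get_nf(4)    -> min(8192, 512) = 512
#   get_nf(64)   -> min(512, 512)  = 512
#   get_nf(128)  -> min(256, 512)  = 256
#   get_nf(1024) -> min(32, 512)   = 32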
|
685 |
+
|
686 |
+
def set_space_of_latent(self, space_of_latent):
|
687 |
+
"""Sets the space to which the latent code belong.
|
688 |
+
|
689 |
+
This function is particularly used for choosing how to inject the latent
|
690 |
+
code into the convolutional layers. The original generator will take a
|
691 |
+
W-Space code and apply it for style modulation after an affine
|
692 |
+
transformation. But, sometimes, one may need to directly feed an already
|
693 |
+
affine-transformed code into the convolutional layer, e.g., when
|
694 |
+
training an encoder for GAN inversion. We term the transformed space as
|
695 |
+
Style Space (or Y-Space). This function is designed to tell the
|
696 |
+
convolutional layers how to use the input code.
|
697 |
+
|
698 |
+
Args:
|
699 |
+
space_of_latent: The space to which the latent code belongs. Case
|
700 |
+
insensitive. Support `W` and `Y`.
|
701 |
+
"""
|
702 |
+
space_of_latent = space_of_latent.upper()
|
703 |
+
for module in self.modules():
|
704 |
+
if isinstance(module, ModulateConvLayer):
|
705 |
+
setattr(module, 'space_of_latent', space_of_latent)
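# Added note: switching to the style space, e.g.
#   synthesis.set_space_of_latent('Y')
# makes every `ModulateConvLayer` treat its input as an already
# affine-transformed style code rather than a W-space code.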
|
706 |
+
|
707 |
+
def forward(self,
|
708 |
+
wp,
|
709 |
+
noise_mode='const',
|
710 |
+
fused_modulate=False,
|
711 |
+
fp16_res=None,
|
712 |
+
impl='cuda'):
|
713 |
+
results = {'wp': wp}
|
714 |
+
|
715 |
+
if self.const_input:
|
716 |
+
x = self.early_layer(wp[:, 0])
|
717 |
+
else:
|
718 |
+
x = self.early_layer(wp[:, 0], impl=impl)
|
719 |
+
|
720 |
+
# Cast to `torch.float16` if needed.
|
721 |
+
if fp16_res is not None and self.init_res >= fp16_res:
|
722 |
+
x = x.to(torch.float16)
|
723 |
+
|
724 |
+
if self.architecture == 'origin':
|
725 |
+
for layer_idx in range(self.num_layers - 1):
|
726 |
+
layer = getattr(self, f'layer{layer_idx}')
|
727 |
+
x, style = layer(x,
|
728 |
+
wp[:, layer_idx],
|
729 |
+
noise_mode=noise_mode,
|
730 |
+
fused_modulate=fused_modulate,
|
731 |
+
impl=impl)
|
732 |
+
results[f'style{layer_idx}'] = style
|
733 |
+
|
734 |
+
# Cast to `torch.float16` if needed.
|
735 |
+
if layer_idx % 2 == 0 and layer_idx != self.num_layers - 2:
|
736 |
+
res = self.init_res * (2 ** (layer_idx // 2))
|
737 |
+
if fp16_res is not None and res * 2 >= fp16_res:
|
738 |
+
x = x.to(torch.float16)
|
739 |
+
else:
|
740 |
+
x = x.to(torch.float32)
|
741 |
+
output_layer = getattr(self, f'output{layer_idx // 2}')
|
742 |
+
image, style = output_layer(x,
|
743 |
+
wp[:, layer_idx + 1],
|
744 |
+
fused_modulate=fused_modulate,
|
745 |
+
impl=impl)
|
746 |
+
image = image.to(torch.float32)
|
747 |
+
results[f'output_style{layer_idx // 2}'] = style
|
748 |
+
|
749 |
+
elif self.architecture == 'skip':
|
750 |
+
for layer_idx in range(self.num_layers - 1):
|
751 |
+
layer = getattr(self, f'layer{layer_idx}')
|
752 |
+
x, style = layer(x,
|
753 |
+
wp[:, layer_idx],
|
754 |
+
noise_mode=noise_mode,
|
755 |
+
fused_modulate=fused_modulate,
|
756 |
+
impl=impl)
|
757 |
+
results[f'style{layer_idx}'] = style
|
758 |
+
if layer_idx % 2 == 0:
|
759 |
+
output_layer = getattr(self, f'output{layer_idx // 2}')
|
760 |
+
y, style = output_layer(x,
|
761 |
+
wp[:, layer_idx + 1],
|
762 |
+
fused_modulate=fused_modulate,
|
763 |
+
impl=impl)
|
764 |
+
results[f'output_style{layer_idx // 2}'] = style
|
765 |
+
if layer_idx == 0:
|
766 |
+
image = y.to(torch.float32)
|
767 |
+
else:
|
768 |
+
image = y.to(torch.float32) + upfirdn2d.upsample2d(
|
769 |
+
image, self.filter, impl=impl)
|
770 |
+
|
771 |
+
# Cast to `torch.float16` if needed.
|
772 |
+
if layer_idx != self.num_layers - 2:
|
773 |
+
res = self.init_res * (2 ** (layer_idx // 2))
|
774 |
+
if fp16_res is not None and res * 2 >= fp16_res:
|
775 |
+
x = x.to(torch.float16)
|
776 |
+
else:
|
777 |
+
x = x.to(torch.float32)
|
778 |
+
|
779 |
+
elif self.architecture == 'resnet':
|
780 |
+
x, style = self.layer0(x,
|
781 |
+
wp[:, 0],
|
782 |
+
noise_mode=noise_mode,
|
783 |
+
fused_modulate=fused_modulate,
|
784 |
+
impl=impl)
|
785 |
+
results['style0'] = style
|
786 |
+
for layer_idx in range(1, self.num_layers - 1, 2):
|
787 |
+
# Cast to `torch.float16` if needed.
|
788 |
+
if layer_idx % 2 == 1:
|
789 |
+
res = self.init_res * (2 ** (layer_idx // 2))
|
790 |
+
if fp16_res is not None and res * 2 >= fp16_res:
|
791 |
+
x = x.to(torch.float16)
|
792 |
+
else:
|
793 |
+
x = x.to(torch.float32)
|
794 |
+
|
795 |
+
skip_layer = getattr(self, f'residual{layer_idx // 2 + 1}')
|
796 |
+
residual = skip_layer(x, runtime_gain=np.sqrt(0.5), impl=impl)
|
797 |
+
layer = getattr(self, f'layer{layer_idx}')
|
798 |
+
x, style = layer(x,
|
799 |
+
wp[:, layer_idx],
|
800 |
+
noise_mode=noise_mode,
|
801 |
+
fused_modulate=fused_modulate,
|
802 |
+
impl=impl)
|
803 |
+
results[f'style{layer_idx}'] = style
|
804 |
+
layer = getattr(self, f'layer{layer_idx + 1}')
|
805 |
+
x, style = layer(x,
|
806 |
+
wp[:, layer_idx + 1],
|
807 |
+
runtime_gain=np.sqrt(0.5),
|
808 |
+
noise_mode=noise_mode,
|
809 |
+
fused_modulate=fused_modulate,
|
810 |
+
impl=impl)
|
811 |
+
results[f'style{layer_idx + 1}'] = style
|
812 |
+
x = x + residual
|
813 |
+
output_layer = getattr(self, f'output{layer_idx // 2 + 1}')
|
814 |
+
image, style = output_layer(x,
|
815 |
+
wp[:, layer_idx + 2],
|
816 |
+
fused_modulate=fused_modulate,
|
817 |
+
impl=impl)
|
818 |
+
image = image.to(torch.float32)
|
819 |
+
results[f'output_style{layer_idx // 2}'] = style
|
820 |
+
|
821 |
+
if self.final_tanh:
|
822 |
+
image = torch.tanh(image)
|
823 |
+
results['image'] = image
|
824 |
+
return results
|
825 |
+
|
826 |
+
|
827 |
+
class PixelNormLayer(nn.Module):
|
828 |
+
"""Implements pixel-wise feature vector normalization layer."""
|
829 |
+
|
830 |
+
def __init__(self, dim, eps):
|
831 |
+
super().__init__()
|
832 |
+
self.dim = dim
|
833 |
+
self.eps = eps
|
834 |
+
|
835 |
+
def extra_repr(self):
|
836 |
+
return f'dim={self.dim}, epsilon={self.eps}'
|
837 |
+
|
838 |
+
def forward(self, x):
|
839 |
+
scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
|
840 |
+
return x * scale
|
841 |
+
|
842 |
+
|
843 |
+
class InputLayer(nn.Module):
|
844 |
+
"""Implements the input layer to start convolution with.
|
845 |
+
|
846 |
+
Basically, this block starts from a const input, which is with shape
|
847 |
+
`(channels, init_res, init_res)`.
|
848 |
+
"""
|
849 |
+
|
850 |
+
def __init__(self, init_res, channels):
|
851 |
+
super().__init__()
|
852 |
+
self.const = nn.Parameter(torch.randn(1, channels, init_res, init_res))
|
853 |
+
|
854 |
+
def forward(self, w):
|
855 |
+
x = self.const.repeat(w.shape[0], 1, 1, 1)
|
856 |
+
return x
|
857 |
+
|
858 |
+
|
859 |
+
class ConvLayer(nn.Module):
|
860 |
+
"""Implements the convolutional layer.
|
861 |
+
|
862 |
+
If upsampling is needed (i.e., `scale_factor = 2`), the feature map will
|
863 |
+
be filtered with `filter_kernel` after convolution. This layer will only be
|
864 |
+
used for the skip connection in the `resnet` architecture.
|
865 |
+
"""
|
866 |
+
|
867 |
+
def __init__(self,
|
868 |
+
in_channels,
|
869 |
+
out_channels,
|
870 |
+
kernel_size,
|
871 |
+
add_bias,
|
872 |
+
scale_factor,
|
873 |
+
filter_kernel,
|
874 |
+
use_wscale,
|
875 |
+
wscale_gain,
|
876 |
+
lr_mul,
|
877 |
+
activation_type,
|
878 |
+
conv_clamp):
|
879 |
+
"""Initializes with layer settings.
|
880 |
+
|
881 |
+
Args:
|
882 |
+
in_channels: Number of channels of the input tensor.
|
883 |
+
out_channels: Number of channels of the output tensor.
|
884 |
+
kernel_size: Size of the convolutional kernels.
|
885 |
+
add_bias: Whether to add bias onto the convolutional result.
|
886 |
+
scale_factor: Scale factor for upsampling.
|
887 |
+
filter_kernel: Kernel used for filtering.
|
888 |
+
use_wscale: Whether to use weight scaling.
|
889 |
+
wscale_gain: Gain factor for weight scaling.
|
890 |
+
lr_mul: Learning multiplier for both weight and bias.
|
891 |
+
activation_type: Type of activation.
|
892 |
+
conv_clamp: A threshold to clamp the output of convolution layers to
|
893 |
+
avoid overflow under FP16 training.
|
894 |
+
"""
|
895 |
+
super().__init__()
|
896 |
+
self.in_channels = in_channels
|
897 |
+
self.out_channels = out_channels
|
898 |
+
self.kernel_size = kernel_size
|
899 |
+
self.add_bias = add_bias
|
900 |
+
self.scale_factor = scale_factor
|
901 |
+
self.filter_kernel = filter_kernel
|
902 |
+
self.use_wscale = use_wscale
|
903 |
+
self.wscale_gain = wscale_gain
|
904 |
+
self.lr_mul = lr_mul
|
905 |
+
self.activation_type = activation_type
|
906 |
+
self.conv_clamp = conv_clamp
|
907 |
+
|
908 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
909 |
+
fan_in = kernel_size * kernel_size * in_channels
|
910 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
911 |
+
if use_wscale:
|
912 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
913 |
+
self.wscale = wscale * lr_mul
|
914 |
+
else:
|
915 |
+
self.weight = nn.Parameter(
|
916 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
917 |
+
self.wscale = lr_mul
|
918 |
+
|
919 |
+
if add_bias:
|
920 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
921 |
+
self.bscale = lr_mul
|
922 |
+
else:
|
923 |
+
self.bias = None
|
924 |
+
self.act_gain = bias_act.activation_funcs[activation_type].def_gain
|
925 |
+
|
926 |
+
if scale_factor > 1:
|
927 |
+
assert filter_kernel is not None
|
928 |
+
self.register_buffer(
|
929 |
+
'filter', upfirdn2d.setup_filter(filter_kernel))
|
930 |
+
fh, fw = self.filter.shape
|
931 |
+
self.filter_padding = (
|
932 |
+
kernel_size // 2 + (fw + scale_factor - 1) // 2,
|
933 |
+
kernel_size // 2 + (fw - scale_factor) // 2,
|
934 |
+
kernel_size // 2 + (fh + scale_factor - 1) // 2,
|
935 |
+
kernel_size // 2 + (fh - scale_factor) // 2)
|
936 |
+
|
937 |
+
def extra_repr(self):
|
938 |
+
return (f'in_ch={self.in_channels}, '
|
939 |
+
f'out_ch={self.out_channels}, '
|
940 |
+
f'ksize={self.kernel_size}, '
|
941 |
+
f'wscale_gain={self.wscale_gain:.3f}, '
|
942 |
+
f'bias={self.add_bias}, '
|
943 |
+
f'lr_mul={self.lr_mul:.3f}, '
|
944 |
+
f'upsample={self.scale_factor}, '
|
945 |
+
f'upsample_filter={self.filter_kernel}, '
|
946 |
+
f'act={self.activation_type}, '
|
947 |
+
f'clamp={self.conv_clamp}')
|
948 |
+
|
949 |
+
def forward(self, x, runtime_gain=1.0, impl='cuda'):
|
950 |
+
dtype = x.dtype
|
951 |
+
|
952 |
+
weight = self.weight
|
953 |
+
if self.wscale != 1.0:
|
954 |
+
weight = weight * self.wscale
|
955 |
+
bias = None
|
956 |
+
if self.bias is not None:
|
957 |
+
bias = self.bias.to(dtype)
|
958 |
+
if self.bscale != 1.0:
|
959 |
+
bias = bias * self.bscale
|
960 |
+
|
961 |
+
if self.scale_factor == 1: # Native convolution without upsampling.
|
962 |
+
padding = self.kernel_size // 2
|
963 |
+
x = conv2d_gradfix.conv2d(
|
964 |
+
x, weight.to(dtype), stride=1, padding=padding, impl=impl)
|
965 |
+
else: # Convolution with upsampling.
|
966 |
+
up = self.scale_factor
|
967 |
+
f = self.filter
|
968 |
+
# When kernel size = 1, use filtering function for upsampling.
|
969 |
+
if self.kernel_size == 1:
|
970 |
+
padding = self.filter_padding
|
971 |
+
x = conv2d_gradfix.conv2d(
|
972 |
+
x, weight.to(dtype), stride=1, padding=0, impl=impl)
|
973 |
+
x = upfirdn2d.upfirdn2d(
|
974 |
+
x, f, up=up, padding=padding, gain=up ** 2, impl=impl)
|
975 |
+
# When kernel size != 1, use transpose convolution for upsampling.
|
976 |
+
else:
|
977 |
+
# Following codes are borrowed from
|
978 |
+
# https://github.com/NVlabs/stylegan2-ada-pytorch
|
979 |
+
px0, px1, py0, py1 = self.filter_padding
|
980 |
+
kh, kw = weight.shape[2:]
|
981 |
+
px0 = px0 - (kw - 1)
|
982 |
+
px1 = px1 - (kw - up)
|
983 |
+
py0 = py0 - (kh - 1)
|
984 |
+
py1 = py1 - (kh - up)
|
985 |
+
pxt = max(min(-px0, -px1), 0)
|
986 |
+
pyt = max(min(-py0, -py1), 0)
|
987 |
+
weight = weight.transpose(0, 1)
|
988 |
+
padding = (pyt, pxt)
|
989 |
+
x = conv2d_gradfix.conv_transpose2d(
|
990 |
+
x, weight.to(dtype), stride=up, padding=padding, impl=impl)
|
991 |
+
padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt)
|
992 |
+
x = upfirdn2d.upfirdn2d(
|
993 |
+
x, f, up=1, padding=padding, gain=up ** 2, impl=impl)
|
994 |
+
|
995 |
+
act_gain = self.act_gain * runtime_gain
|
996 |
+
act_clamp = None
|
997 |
+
if self.conv_clamp is not None:
|
998 |
+
act_clamp = self.conv_clamp * runtime_gain
|
999 |
+
x = bias_act.bias_act(x, bias,
|
1000 |
+
act=self.activation_type,
|
1001 |
+
gain=act_gain,
|
1002 |
+
clamp=act_clamp,
|
1003 |
+
impl=impl)
|
1004 |
+
|
1005 |
+
assert x.dtype == dtype
|
1006 |
+
return x
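# Usage sketch (added annotation): the 1x1 residual/skip branch that
# doubles the resolution, as constructed by `SynthesisNetwork` for the
# `resnet` architecture:
#
#   skip = ConvLayer(in_channels=512, out_channels=512, kernel_size=1,
#                    add_bias=False, scale_factor=2,
#                    filter_kernel=(1, 3, 3, 1), use_wscale=True,
#                    wscale_gain=1.0, lr_mul=1.0,
#                    activation_type='linear', conv_clamp=None)
#   y = skip(torch.randn(1, 512, 8, 8), impl='ref')
#   assert y.shape == (1, 512, 16, 16)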
|
1007 |
+
|
1008 |
+
|
1009 |
+
class ModulateConvLayer(nn.Module):
    """Implements the convolutional layer with style modulation."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 resolution,
                 w_dim,
                 kernel_size,
                 add_bias,
                 scale_factor,
                 filter_kernel,
                 demodulate,
                 use_wscale,
                 wscale_gain,
                 lr_mul,
                 noise_type,
                 activation_type,
                 conv_clamp,
                 eps):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            resolution: Resolution of the output tensor.
            w_dim: Dimension of W space for style modulation.
            kernel_size: Size of the convolutional kernels.
            add_bias: Whether to add bias onto the convolutional result.
            scale_factor: Scale factor for upsampling.
            filter_kernel: Kernel used for filtering.
            demodulate: Whether to perform style demodulation.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            lr_mul: Learning rate multiplier for both weight and bias.
            noise_type: Type of noise added to the feature map after the
                convolution (if needed). Support `none`, `spatial` and
                `channel`.
            activation_type: Type of activation.
            conv_clamp: A threshold to clamp the output of convolution layers
                to avoid overflow under FP16 training.
            eps: A small value to avoid divide overflow.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.resolution = resolution
        self.w_dim = w_dim
        self.kernel_size = kernel_size
        self.add_bias = add_bias
        self.scale_factor = scale_factor
        self.filter_kernel = filter_kernel
        self.demodulate = demodulate
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.noise_type = noise_type.lower()
        self.activation_type = activation_type
        self.conv_clamp = conv_clamp
        self.eps = eps

        self.space_of_latent = 'W'

        # Set up weight.
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(*weight_shape) * wscale / lr_mul)
            self.wscale = lr_mul

        # Set up bias.
        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None
        self.act_gain = bias_act.activation_funcs[activation_type].def_gain

        # Set up style.
        self.style = DenseLayer(in_channels=w_dim,
                                out_channels=in_channels,
                                add_bias=True,
                                init_bias=1.0,
                                use_wscale=use_wscale,
                                wscale_gain=wscale_gain,
                                lr_mul=lr_mul,
                                activation_type='linear')

        # Set up noise.
        if self.noise_type != 'none':
            self.noise_strength = nn.Parameter(torch.zeros(()))
            if self.noise_type == 'spatial':
                self.register_buffer(
                    'noise', torch.randn(1, 1, resolution, resolution))
            elif self.noise_type == 'channel':
                self.register_buffer(
                    'noise', torch.randn(1, out_channels, 1, 1))
            else:
                raise NotImplementedError(f'Not implemented noise type: '
                                          f'`{self.noise_type}`!')

        if scale_factor > 1:
            assert filter_kernel is not None
            self.register_buffer(
                'filter', upfirdn2d.setup_filter(filter_kernel))
            fh, fw = self.filter.shape
            self.filter_padding = (
                kernel_size // 2 + (fw + scale_factor - 1) // 2,
                kernel_size // 2 + (fw - scale_factor) // 2,
                kernel_size // 2 + (fh + scale_factor - 1) // 2,
                kernel_size // 2 + (fh - scale_factor) // 2)

    def extra_repr(self):
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'ksize={self.kernel_size}, '
                f'wscale_gain={self.wscale_gain:.3f}, '
                f'bias={self.add_bias}, '
                f'lr_mul={self.lr_mul:.3f}, '
                f'upsample={self.scale_factor}, '
                f'upsample_filter={self.filter_kernel}, '
                f'demodulate={self.demodulate}, '
                f'noise_type={self.noise_type}, '
                f'act={self.activation_type}, '
                f'clamp={self.conv_clamp}')

    def forward_style(self, w, impl='cuda'):
        """Gets style code from the given input.

        More specifically, if the input is from W-Space, it will be projected
        by an affine transformation. If it is from the Style Space (Y-Space),
        no operation is required.

        NOTE: For codes from Y-Space, we use slicing to make sure the
        dimension is correct, in case the code is padded before being fed into
        this layer.
        """
        space_of_latent = self.space_of_latent.upper()
        if space_of_latent == 'W':
            if w.ndim != 2 or w.shape[1] != self.w_dim:
                raise ValueError(f'The input tensor should be with shape '
                                 f'[batch_size, w_dim], where '
                                 f'`w_dim` equals to {self.w_dim}!\n'
                                 f'But `{w.shape}` is received!')
            style = self.style(w, impl=impl)
        elif space_of_latent == 'Y':
            if w.ndim != 2 or w.shape[1] < self.in_channels:
                raise ValueError(f'The input tensor should be with shape '
                                 f'[batch_size, y_dim], where '
                                 f'`y_dim` equals to {self.in_channels}!\n'
                                 f'But `{w.shape}` is received!')
            style = w[:, :self.in_channels]
        else:
            raise NotImplementedError(f'Not implemented `space_of_latent`: '
                                      f'`{space_of_latent}`!')
        return style

    def forward(self,
                x,
                w,
                runtime_gain=1.0,
                noise_mode='const',
                fused_modulate=False,
                impl='cuda'):
        dtype = x.dtype
        N, C, H, W = x.shape

        fused_modulate = (fused_modulate and
                          not self.training and
                          (dtype == torch.float32 or N == 1))

        weight = self.weight
        out_ch, in_ch, kh, kw = weight.shape
        assert in_ch == C

        # Affine on `w`.
        style = self.forward_style(w, impl=impl)
        if not self.demodulate:
            _style = style * self.wscale  # Equivalent to scaling weight.
        else:
            _style = style

        # Prepare noise.
        noise = None
        noise_mode = noise_mode.lower()
        if self.noise_type != 'none' and noise_mode != 'none':
            if noise_mode == 'random':
                noise = torch.randn((N, *self.noise.shape[1:]), device=x.device)
            elif noise_mode == 'const':
                noise = self.noise
            else:
                raise ValueError(f'Unknown noise mode `{noise_mode}`!')
            noise = (noise * self.noise_strength).to(dtype)

        # Pre-normalize inputs to avoid FP16 overflow.
        if dtype == torch.float16 and self.demodulate:
            weight_max = weight.norm(float('inf'), dim=(1, 2, 3), keepdim=True)
            weight = weight * (self.wscale / weight_max)
            style_max = _style.norm(float('inf'), dim=1, keepdim=True)
            _style = _style / style_max

        if self.demodulate or fused_modulate:
            _weight = weight.unsqueeze(0)
            _weight = _weight * _style.reshape(N, 1, in_ch, 1, 1)
        if self.demodulate:
            decoef = (_weight.square().sum(dim=(2, 3, 4)) + self.eps).rsqrt()
        if self.demodulate and fused_modulate:
            _weight = _weight * decoef.reshape(N, out_ch, 1, 1, 1)

        if not fused_modulate:
            x = x * _style.to(dtype).reshape(N, in_ch, 1, 1)
            w = weight.to(dtype)
            groups = 1
        else:  # Use group convolution to fuse style modulation and convolution.
            x = x.reshape(1, N * in_ch, H, W)
            w = _weight.reshape(N * out_ch, in_ch, kh, kw).to(dtype)
            groups = N

        if self.scale_factor == 1:  # Native convolution without upsampling.
            up = 1
            padding = self.kernel_size // 2
            x = conv2d_gradfix.conv2d(
                x, w, stride=1, padding=padding, groups=groups, impl=impl)
        else:  # Convolution with upsampling.
            up = self.scale_factor
            f = self.filter
            # When kernel size = 1, use filtering function for upsampling.
            if self.kernel_size == 1:
                padding = self.filter_padding
                x = conv2d_gradfix.conv2d(
                    x, w, stride=1, padding=0, groups=groups, impl=impl)
                x = upfirdn2d.upfirdn2d(
                    x, f, up=up, padding=padding, gain=up ** 2, impl=impl)
            # When kernel size != 1, use transpose convolution for upsampling.
            else:
                # The following code is borrowed from
                # https://github.com/NVlabs/stylegan2-ada-pytorch
                px0, px1, py0, py1 = self.filter_padding
                px0 = px0 - (kw - 1)
                px1 = px1 - (kw - up)
                py0 = py0 - (kh - 1)
                py1 = py1 - (kh - up)
                pxt = max(min(-px0, -px1), 0)
                pyt = max(min(-py0, -py1), 0)
                if groups == 1:
                    w = w.transpose(0, 1)
                else:
                    w = w.reshape(N, out_ch, in_ch, kh, kw)
                    w = w.transpose(1, 2)
                    w = w.reshape(N * in_ch, out_ch, kh, kw)
                padding = (pyt, pxt)
                x = conv2d_gradfix.conv_transpose2d(
                    x, w, stride=up, padding=padding, groups=groups, impl=impl)
                padding = (px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt)
                x = upfirdn2d.upfirdn2d(
                    x, f, up=1, padding=padding, gain=up ** 2, impl=impl)

        if not fused_modulate:
            if self.demodulate:
                decoef = decoef.to(dtype).reshape(N, out_ch, 1, 1)
            if self.demodulate and noise is not None:
                x = fma.fma(x, decoef, noise, impl=impl)
            else:
                if self.demodulate:
                    x = x * decoef
                if noise is not None:
                    x = x + noise
        else:
            x = x.reshape(N, out_ch, H * up, W * up)
            if noise is not None:
                x = x + noise

        bias = None
        if self.bias is not None:
            bias = self.bias.to(dtype)
            if self.bscale != 1.0:
                bias = bias * self.bscale

        if self.activation_type == 'linear':  # Shortcut for output layer.
            x = bias_act.bias_act(
                x, bias, act='linear', clamp=self.conv_clamp, impl=impl)
        else:
            act_gain = self.act_gain * runtime_gain
            act_clamp = None
            if self.conv_clamp is not None:
                act_clamp = self.conv_clamp * runtime_gain
            x = bias_act.bias_act(x, bias,
                                  act=self.activation_type,
                                  gain=act_gain,
                                  clamp=act_clamp,
                                  impl=impl)

        assert x.dtype == dtype
        assert style.dtype == torch.float32
        return x, style

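
# Editorial reference sketch (not used by the model): the fused group-conv
# path in `ModulateConvLayer.forward` above is numerically equivalent to this
# plain PyTorch implementation of modulation + demodulation, assuming an odd
# kernel size, stride 1, no upsampling, and no noise.
def _modulated_conv2d_reference(x, weight, style, eps=1e-8):
    """x: [N, C, H, W]; weight: [O, C, kh, kw]; style: [N, C]."""
    n, _, h, w_ = x.shape
    out_ch = weight.shape[0]
    w = weight.unsqueeze(0) * style.reshape(n, 1, -1, 1, 1)  # Modulate.
    decoef = (w.square().sum(dim=(2, 3, 4)) + eps).rsqrt()   # Demodulate.
    w = w * decoef.reshape(n, out_ch, 1, 1, 1)
    x = x.reshape(1, -1, h, w_)                              # Fuse into groups.
    w = w.reshape(n * out_ch, *w.shape[2:])
    out = torch.nn.functional.conv2d(
        x, w, padding=weight.shape[-1] // 2, groups=n)
    return out.reshape(n, out_ch, h, w_)
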
class DenseLayer(nn.Module):
    """Implements the dense layer."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 add_bias,
                 init_bias,
                 use_wscale,
                 wscale_gain,
                 lr_mul,
                 activation_type):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
            init_bias: The initial bias value before training.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            lr_mul: Learning rate multiplier for both weight and bias.
            activation_type: Type of activation.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_bias = add_bias
        self.init_bias = init_bias
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.activation_type = activation_type

        weight_shape = (out_channels, in_channels)
        wscale = wscale_gain / np.sqrt(in_channels)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(*weight_shape) * wscale / lr_mul)
            self.wscale = lr_mul

        if add_bias:
            init_bias = np.float32(init_bias) / lr_mul
            self.bias = nn.Parameter(torch.full([out_channels], init_bias))
            self.bscale = lr_mul
        else:
            self.bias = None

    def extra_repr(self):
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'wscale_gain={self.wscale_gain:.3f}, '
                f'bias={self.add_bias}, '
                f'init_bias={self.init_bias}, '
                f'lr_mul={self.lr_mul:.3f}, '
                f'act={self.activation_type}')

    def forward(self, x, impl='cuda'):
        dtype = x.dtype

        if x.ndim != 2:
            x = x.flatten(start_dim=1)

        weight = self.weight.to(dtype) * self.wscale
        bias = None
        if self.bias is not None:
            bias = self.bias.to(dtype)
            if self.bscale != 1.0:
                bias = bias * self.bscale

        # Fast pass for linear activation.
        if self.activation_type == 'linear' and bias is not None:
            x = torch.addmm(bias.unsqueeze(0), x, weight.t())
        else:
            x = x.matmul(weight.t())
            x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl)

        assert x.dtype == dtype
        return x

# pylint: enable=missing-function-docstring
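For reference, the `use_wscale` branch in `DenseLayer` above implements the equalized learning rate trick: parameters are stored near unit scale and rescaled at runtime, so the effective weight always has standard deviation `wscale_gain / sqrt(fan_in)` while gradients are scaled by `lr_mul`. A minimal standalone sketch of the effective forward computation (hypothetical sizes, assuming `wscale_gain=1.0` and `lr_mul=0.01`):

import numpy as np
import torch

in_ch, out_ch, lr_mul, gain = 512, 512, 0.01, 1.0
weight = torch.randn(out_ch, in_ch) / lr_mul  # Stored parameter.
wscale = gain / np.sqrt(in_ch) * lr_mul       # Runtime multiplier.
bias = torch.zeros(out_ch)                    # Stored bias, scaled by lr_mul.
x = torch.randn(4, in_ch)
y = torch.addmm((bias * lr_mul).unsqueeze(0), x, (weight * wscale).t())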
models/stylegan3_generator.py
ADDED
@@ -0,0 +1,1332 @@
# python3.7
"""Contains the implementation of the generator described in StyleGAN3.

Compared to that of StyleGAN2, the generator in StyleGAN3 controls the
frequency flow as the convolutional layers grow.

Paper: https://arxiv.org/pdf/2106.12423.pdf

Official implementation: https://github.com/NVlabs/stylegan3
"""

import numpy as np
import scipy.signal

import torch
import torch.nn as nn
import torch.nn.functional as F

from third_party.stylegan3_official_ops import bias_act
from third_party.stylegan3_official_ops import filtered_lrelu
from third_party.stylegan3_official_ops import conv2d_gradfix
from .utils.ops import all_gather

__all__ = ['StyleGAN3Generator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# pylint: disable=missing-function-docstring

class StyleGAN3Generator(nn.Module):
    """Defines the generator network in StyleGAN3.

    NOTE: The synthesized images are with `RGB` channel order and pixel range
    [-1, 1].

    Settings for the mapping network:

    (1) z_dim: Dimension of the input latent space, Z. (default: 512)
    (2) w_dim: Dimension of the output latent space, W. (default: 512)
    (3) repeat_w: Repeat w-code for different layers. (default: True)
    (4) normalize_z: Whether to normalize the z-code. (default: True)
    (5) mapping_layers: Number of layers of the mapping network. (default: 2)
    (6) mapping_fmaps: Number of hidden channels of the mapping network.
        (default: 512)
    (7) mapping_lr_mul: Learning rate multiplier for the mapping network.
        (default: 0.01)

    Settings for conditional generation:

    (1) label_dim: Dimension of the additional label for conditional
        generation. In the one-hot conditioning case, it is equal to the
        number of classes. If set to 0, conditioning training will be
        disabled. (default: 0)
    (2) embedding_dim: Dimension of the embedding space, if needed.
        (default: 512)
    (3) embedding_bias: Whether to add bias to embedding learning.
        (default: True)
    (4) embedding_lr_mul: Learning rate multiplier for the embedding learning.
        (default: 1.0)
    (5) normalize_embedding: Whether to normalize the embedding.
        (default: True)
    (6) normalize_embedding_latent: Whether to normalize the embedding
        together with the latent. (default: False)

    Settings for the synthesis network:

    (1) resolution: The resolution of the output image. (default: -1)
    (2) image_channels: Number of channels of the output image. (default: 3)
    (3) final_tanh: Whether to use `tanh` to control the final pixel range.
        (default: False)
    (4) output_scale: Factor for scaling the output image. (default: 0.25)
    (5) num_layers: Number of synthesis layers, excluding the first positional
        encoding layer and the last ToRGB layer. (default: 14)
    (6) num_critical: Number of synthesis layers with critical sampling. These
        layers are always set as the top (highest resolution) ones.
        (default: 2)
    (7) fmaps_base: Factor to control the number of feature maps for each
        layer. (default: 32 << 10)
    (8) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    (9) kernel_size: Size of convolutional kernels. (default: 1)
    (10) conv_clamp: A threshold to clamp the output of convolution layers to
        avoid overflow under FP16 training. (default: None)
    (11) first_cutoff: Cutoff frequency of the first layer. (default: 2)
    (12) first_stopband: Stopband of the first layer. (default: 2 ** 2.1)
    (13) last_stopband_rel: Stopband of the last layer, relative to the last
        cutoff, which is `resolution / 2`. Concretely, `last_stopband` will be
        equal to `resolution / 2 * last_stopband_rel`. (default: 2 ** 0.3)
    (14) margin_size: Size of margin for each feature map. (default: 10)
    (15) filter_size: Size of filter for upsampling and downsampling around
        the activation. (default: 6)
    (16) act_upsampling: Factor used to upsample the feature map before
        activation for anti-aliasing. (default: 2)
    (17) use_radial_filter: Whether to use radial filter for downsampling
        after the activation. (default: False)
    (18) eps: A small value to avoid divide overflow. (default: 1e-8)

    Runtime settings:

    (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for
        training only. Set `None` to disable. (default: 0.998)
    (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If set
        as `True`, the stats will be more accurate, yet the speed may be a
        little bit slower. (default: False)
    (3) style_mixing_prob: Probability to perform style mixing as a training
        regularization. Set `None` to disable. (default: None)
    (4) trunc_psi: Truncation psi, set `None` to disable. (default: None)
    (5) trunc_layers: Number of layers to perform truncation. (default: None)
    (6) magnitude_moving_decay: Decay factor for updating `magnitude_ema` in
        each `SynthesisLayer`, which is used for training only. Set `None` to
        disable. (default: 0.999)
    (7) update_ema: Whether to update `w_avg` in the `MappingNetwork` and
        `magnitude_ema` in each `SynthesisLayer`. This field only takes effect
        in `training` mode. (default: False)
    (8) fp16_res: Layers at resolution higher than (or equal to) this field
        will use `float16` precision for computation. This is merely used for
        acceleration. If set as `None`, all layers will use `float32` by
        default. (default: None)
    (9) impl: Implementation mode of some particular ops, e.g., `filtering`,
        `bias_act`, etc. `cuda` means using the official CUDA implementation
        from StyleGAN3, while `ref` means using the native PyTorch ops.
        (default: `cuda`)
    """

    def __init__(self,
                 # Settings for mapping network.
                 z_dim=512,
                 w_dim=512,
                 repeat_w=True,
                 normalize_z=True,
                 mapping_layers=2,
                 mapping_fmaps=512,
                 mapping_lr_mul=0.01,
                 # Settings for conditional generation.
                 label_dim=0,
                 embedding_dim=512,
                 embedding_bias=True,
                 embedding_lr_mul=1.0,
                 normalize_embedding=True,
                 normalize_embedding_latent=False,
                 # Settings for synthesis network.
                 resolution=-1,
                 image_channels=3,
                 final_tanh=False,
                 output_scale=0.25,
                 num_layers=14,
                 num_critical=2,
                 fmaps_base=32 << 10,
                 fmaps_max=512,
                 kernel_size=1,
                 conv_clamp=256,
                 first_cutoff=2,
                 first_stopband=2 ** 2.1,
                 last_stopband_rel=2 ** 0.3,
                 margin_size=10,
                 filter_size=6,
                 act_upsampling=2,
                 use_radial_filter=False,
                 eps=1e-8):
        """Initializes with basic settings."""
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.z_dim = z_dim
        self.w_dim = w_dim
        self.repeat_w = repeat_w
        self.normalize_z = normalize_z
        self.mapping_layers = mapping_layers
        self.mapping_fmaps = mapping_fmaps
        self.mapping_lr_mul = mapping_lr_mul

        self.label_dim = label_dim
        self.embedding_dim = embedding_dim
        self.embedding_bias = embedding_bias
        self.embedding_lr_mul = embedding_lr_mul
        self.normalize_embedding = normalize_embedding
        self.normalize_embedding_latent = normalize_embedding_latent

        self.resolution = resolution
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.output_scale = output_scale
        self.num_layers = num_layers + 2  # Including InputLayer and ToRGBLayer.
        self.num_critical = num_critical
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.kernel_size = kernel_size
        self.conv_clamp = conv_clamp
        self.first_cutoff = first_cutoff
        self.first_stopband = first_stopband
        self.last_stopband_rel = last_stopband_rel
        self.margin_size = margin_size
        self.filter_size = filter_size
        self.act_upsampling = act_upsampling
        self.use_radial_filter = use_radial_filter
        self.eps = eps

        # Dimension of latent space, which is convenient for sampling.
        self.latent_dim = (z_dim,)

        self.mapping = MappingNetwork(
            input_dim=z_dim,
            output_dim=w_dim,
            num_outputs=self.num_layers,
            repeat_output=repeat_w,
            normalize_input=normalize_z,
            num_layers=mapping_layers,
            hidden_dim=mapping_fmaps,
            lr_mul=mapping_lr_mul,
            label_dim=label_dim,
            embedding_dim=embedding_dim,
            embedding_bias=embedding_bias,
            embedding_lr_mul=embedding_lr_mul,
            normalize_embedding=normalize_embedding,
            normalize_embedding_latent=normalize_embedding_latent,
            eps=eps)

        # This is used for truncation trick.
        if self.repeat_w:
            self.register_buffer('w_avg', torch.zeros(w_dim))
        else:
            self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim))

        self.synthesis = SynthesisNetwork(resolution=resolution,
                                          w_dim=w_dim,
                                          image_channels=image_channels,
                                          final_tanh=final_tanh,
                                          output_scale=output_scale,
                                          num_layers=num_layers,
                                          num_critical=num_critical,
                                          fmaps_base=fmaps_base,
                                          fmaps_max=fmaps_max,
                                          kernel_size=kernel_size,
                                          conv_clamp=conv_clamp,
                                          first_cutoff=first_cutoff,
                                          first_stopband=first_stopband,
                                          last_stopband_rel=last_stopband_rel,
                                          margin_size=margin_size,
                                          filter_size=filter_size,
                                          act_upsampling=act_upsampling,
                                          use_radial_filter=use_radial_filter,
                                          eps=eps)

        self.var_mapping = {'w_avg': 'mapping.w_avg'}
        for key, val in self.mapping.var_mapping.items():
            self.var_mapping[f'mapping.{key}'] = f'mapping.{val}'
        for key, val in self.synthesis.var_mapping.items():
            self.var_mapping[f'synthesis.{key}'] = f'synthesis.{val}'

    def set_space_of_latent(self, space_of_latent):
        """Sets the space to which the latent code belongs.

        See `SynthesisNetwork` for more details.
        """
        self.synthesis.set_space_of_latent(space_of_latent)

    def forward(self,
                z,
                label=None,
                w_moving_decay=0.998,
                sync_w_avg=False,
                style_mixing_prob=None,
                trunc_psi=None,
                trunc_layers=None,
                magnitude_moving_decay=0.999,
                update_ema=False,
                fp16_res=None,
                impl='cuda'):
        """Connects the mapping network and the synthesis network.

        This forward function will also update the average `w_code`, perform
        style mixing as a training regularizer, and do the truncation trick,
        which is specially designed for inference.

        Concretely, the truncation trick acts as follows:

        For layers in range [0, truncation_layers), the truncated w-code is
        computed as

        w_new = w_avg + (w - w_avg) * truncation_psi

        To disable truncation, please set

        (1) truncation_psi = 1.0 (None) OR
        (2) truncation_layers = 0 (None)
        """

        mapping_results = self.mapping(z, label, impl=impl)

        w = mapping_results['w']
        if self.training and update_ema and w_moving_decay is not None:
            if sync_w_avg:
                batch_w_avg = all_gather(w.detach()).mean(dim=0)
            else:
                batch_w_avg = w.detach().mean(dim=0)
            self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay))

        wp = mapping_results.pop('wp')
        if self.training and style_mixing_prob is not None:
            if np.random.uniform() < style_mixing_prob:
                new_z = torch.randn_like(z)
                new_wp = self.mapping(new_z, label, impl=impl)['wp']
                mixing_cutoff = np.random.randint(1, self.num_layers)
                wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:]

        if not self.training:
            trunc_psi = 1.0 if trunc_psi is None else trunc_psi
            trunc_layers = 0 if trunc_layers is None else trunc_layers
            if trunc_psi < 1.0 and trunc_layers > 0:
                w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers]
                wp[:, :trunc_layers] = w_avg.lerp(
                    wp[:, :trunc_layers], trunc_psi)

        synthesis_results = self.synthesis(
            wp,
            magnitude_moving_decay=magnitude_moving_decay,
            update_ema=update_ema,
            fp16_res=fp16_res,
            impl=impl)

        return {**mapping_results, **synthesis_results}

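
# Editorial sketch (not part of the original file) of the truncation trick
# documented in `StyleGAN3Generator.forward` above, with hypothetical shapes:
# for the first `trunc_layers` layers, each w-code is pulled toward the
# running average, i.e., w_new = w_avg + (w - w_avg) * trunc_psi.
def _truncate_reference(wp, w_avg, trunc_psi=0.7, trunc_layers=8):
    """wp: [N, num_layers, w_dim]; w_avg: [w_dim]."""
    w_avg = w_avg.reshape(1, 1, -1)
    wp = wp.clone()
    wp[:, :trunc_layers] = w_avg.lerp(wp[:, :trunc_layers], trunc_psi)
    return wp
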
class MappingNetwork(nn.Module):
    """Implements the latent space mapping network.

    Basically, this network executes several dense layers in sequence, plus a
    label embedding if needed.
    """

    def __init__(self,
                 input_dim,
                 output_dim,
                 num_outputs,
                 repeat_output,
                 normalize_input,
                 num_layers,
                 hidden_dim,
                 lr_mul,
                 label_dim,
                 embedding_dim,
                 embedding_bias,
                 embedding_lr_mul,
                 normalize_embedding,
                 normalize_embedding_latent,
                 eps):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_outputs = num_outputs
        self.repeat_output = repeat_output
        self.normalize_input = normalize_input
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.lr_mul = lr_mul
        self.label_dim = label_dim
        self.embedding_dim = embedding_dim
        self.embedding_bias = embedding_bias
        self.embedding_lr_mul = embedding_lr_mul
        self.normalize_embedding = normalize_embedding
        self.normalize_embedding_latent = normalize_embedding_latent
        self.eps = eps

        self.var_mapping = {}

        self.norm = PixelNormLayer(dim=1, eps=eps)

        if self.label_dim > 0:
            input_dim = input_dim + embedding_dim
            self.embedding = DenseLayer(in_channels=label_dim,
                                        out_channels=embedding_dim,
                                        init_weight_std=1.0,
                                        add_bias=embedding_bias,
                                        init_bias=0.0,
                                        lr_mul=embedding_lr_mul,
                                        activation_type='linear')
            self.var_mapping['embedding.weight'] = 'embed.weight'
            if self.embedding_bias:
                self.var_mapping['embedding.bias'] = 'embed.bias'

        if num_outputs is not None and not repeat_output:
            output_dim = output_dim * num_outputs
        for i in range(num_layers):
            in_channels = (input_dim if i == 0 else hidden_dim)
            out_channels = (output_dim if i == (num_layers - 1) else hidden_dim)
            self.add_module(f'dense{i}',
                            DenseLayer(in_channels=in_channels,
                                       out_channels=out_channels,
                                       init_weight_std=1.0,
                                       add_bias=True,
                                       init_bias=0.0,
                                       lr_mul=lr_mul,
                                       activation_type='lrelu'))
            self.var_mapping[f'dense{i}.weight'] = f'fc{i}.weight'
            self.var_mapping[f'dense{i}.bias'] = f'fc{i}.bias'

    def forward(self, z, label=None, impl='cuda'):
        if z.ndim != 2 or z.shape[1] != self.input_dim:
            raise ValueError(f'Input latent code should be with shape '
                             f'[batch_size, input_dim], where '
                             f'`input_dim` equals to {self.input_dim}!\n'
                             f'But `{z.shape}` is received!')
        if self.normalize_input:
            z = self.norm(z)

        if self.label_dim > 0:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with dimension {self.label_dim}) as input, '
                                 f'but no label is received!')
            if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals to that of '
                                 f'latent codes ({z.shape[0]}) and '
                                 f'`label_dim` equals to {self.label_dim}!\n'
                                 f'But `{label.shape}` is received!')
            label = label.to(dtype=torch.float32)
            embedding = self.embedding(label, impl=impl)
            if self.normalize_embedding:
                embedding = self.norm(embedding)
            w = torch.cat((z, embedding), dim=1)
        else:
            w = z

        if self.label_dim > 0 and self.normalize_embedding_latent:
            w = self.norm(w)

        for i in range(self.num_layers):
            w = getattr(self, f'dense{i}')(w, impl=impl)

        wp = None
        if self.num_outputs is not None:
            if self.repeat_output:
                wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1))
            else:
                wp = w.reshape(-1, self.num_outputs, self.output_dim)

        results = {
            'z': z,
            'label': label,
            'w': w,
            'wp': wp,
        }
        if self.label_dim > 0:
            results['embedding'] = embedding
        return results

class SynthesisNetwork(nn.Module):
    """Implements the image synthesis network.

    Basically, this network executes several convolutional layers in sequence.
    """

    def __init__(self,
                 resolution,
                 w_dim,
                 image_channels,
                 final_tanh,
                 output_scale,
                 num_layers,
                 num_critical,
                 fmaps_base,
                 fmaps_max,
                 kernel_size,
                 conv_clamp,
                 first_cutoff,
                 first_stopband,
                 last_stopband_rel,
                 margin_size,
                 filter_size,
                 act_upsampling,
                 use_radial_filter,
                 eps):
        super().__init__()

        self.resolution = resolution
        self.w_dim = w_dim
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.output_scale = output_scale
        self.num_layers = num_layers
        self.num_critical = num_critical
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.kernel_size = kernel_size
        self.conv_clamp = conv_clamp
        self.first_cutoff = first_cutoff
        self.first_stopband = first_stopband
        self.last_stopband_rel = last_stopband_rel
        self.margin_size = margin_size
        self.filter_size = filter_size
        self.act_upsampling = act_upsampling
        self.use_radial_filter = use_radial_filter
        self.eps = eps

        self.var_mapping = {}

        # Get layer settings.
        last_cutoff = resolution / 2
        last_stopband = last_cutoff * last_stopband_rel
        layer_indices = np.arange(num_layers + 1)
        exponents = np.minimum(layer_indices / (num_layers - num_critical), 1)
        cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents
        stopbands = (
            first_stopband * (last_stopband / first_stopband) ** exponents)
        sampling_rates = np.exp2(np.ceil(np.log2(
            np.minimum(stopbands * 2, self.resolution))))
        sampling_rates = np.int64(sampling_rates)
        half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs
        sizes = sampling_rates + margin_size * 2
        sizes[-2:] = resolution
        sizes = np.int64(sizes)
        channels = np.rint(np.minimum((fmaps_base / 2) / cutoffs, fmaps_max))
        channels[-1] = image_channels
        channels = np.int64(channels)

        self.cutoffs = cutoffs
        self.stopbands = stopbands
        self.sampling_rates = sampling_rates
        self.half_widths = half_widths
        self.sizes = sizes
        self.channels = channels

        # Input layer, with positional encoding.
        self.early_layer = InputLayer(w_dim=w_dim,
                                      channels=channels[0],
                                      size=sizes[0],
                                      sampling_rate=sampling_rates[0],
                                      cutoff=cutoffs[0])
        self.var_mapping['early_layer.weight'] = 'input.weight'
        self.var_mapping['early_layer.affine.weight'] = 'input.affine.weight'
        self.var_mapping['early_layer.affine.bias'] = 'input.affine.bias'
        self.var_mapping['early_layer.transform'] = 'input.transform'
        self.var_mapping['early_layer.frequency'] = 'input.freqs'
        self.var_mapping['early_layer.phase'] = 'input.phases'

        # Convolutional layers.
        for idx in range(num_layers + 1):
            # Position related settings.
            if idx < num_layers:
                kernel_size = self.kernel_size
                demodulate = True
                act_upsampling = self.act_upsampling
            else:  # ToRGB layer.
                kernel_size = 1
                demodulate = False
                act_upsampling = 1
            if idx < num_layers - num_critical:  # Non-critical sampling.
                use_radial_filter = self.use_radial_filter
            else:  # Critical sampling.
                use_radial_filter = False

            prev_idx = max(idx - 1, 0)
            layer_name = f'layer{idx}'
            official_layer_name = f'L{idx}_{sizes[idx]}_{channels[idx]}'
            self.add_module(
                layer_name,
                SynthesisLayer(in_channels=channels[prev_idx],
                               out_channels=channels[idx],
                               w_dim=w_dim,
                               kernel_size=kernel_size,
                               demodulate=demodulate,
                               eps=eps,
                               conv_clamp=conv_clamp,
                               in_size=sizes[prev_idx],
                               out_size=sizes[idx],
                               in_sampling_rate=sampling_rates[prev_idx],
                               out_sampling_rate=sampling_rates[idx],
                               in_cutoff=cutoffs[prev_idx],
                               out_cutoff=cutoffs[idx],
                               in_half_width=half_widths[prev_idx],
                               out_half_width=half_widths[idx],
                               filter_size=filter_size,
                               use_radial_filter=use_radial_filter,
                               act_upsampling=act_upsampling))

            self.var_mapping[f'{layer_name}.magnitude_ema'] = (
                f'{official_layer_name}.magnitude_ema')
            self.var_mapping[f'{layer_name}.conv.weight'] = (
                f'{official_layer_name}.weight')
            self.var_mapping[f'{layer_name}.conv.style.weight'] = (
                f'{official_layer_name}.affine.weight')
            self.var_mapping[f'{layer_name}.conv.style.bias'] = (
                f'{official_layer_name}.affine.bias')
            self.var_mapping[f'{layer_name}.filter.bias'] = (
                f'{official_layer_name}.bias')
            if idx < num_layers:  # ToRGB layer does not need filters.
                self.var_mapping[f'{layer_name}.filter.up_filter'] = (
                    f'{official_layer_name}.up_filter')
                self.var_mapping[f'{layer_name}.filter.down_filter'] = (
                    f'{official_layer_name}.down_filter')

    def set_space_of_latent(self, space_of_latent):
        """Sets the space to which the latent code belongs.

        This function is particularly used for choosing how to inject the
        latent code into the convolutional layers. The original generator will
        take a W-Space code and apply it for style modulation after an affine
        transformation. But, sometimes, it may need to directly feed an
        already affine-transformed code into the convolutional layer, e.g.,
        when training an encoder for GAN inversion. We term the transformed
        space as Style Space (or Y-Space). This function is designed to tell
        the convolutional layers how to use the input code.

        Args:
            space_of_latent: The space to which the latent code belongs. Case
                insensitive. Support `W` and `Y`.
        """
        space_of_latent = space_of_latent.upper()
        for module in self.modules():
            if isinstance(module, ModulateConvLayer):
                setattr(module, 'space_of_latent', space_of_latent)

    def forward(self,
                wp,
                magnitude_moving_decay=0.999,
                update_ema=False,
                fp16_res=None,
                impl='cuda'):
        results = {'wp': wp}

        x = self.early_layer(wp[:, 0])
        for idx, sampling_rate in enumerate(self.sampling_rates):
            if fp16_res is not None and sampling_rate >= fp16_res:
                x = x.to(torch.float16)
            layer = getattr(self, f'layer{idx}')
            x, style = layer(x, wp[:, idx + 1],
                             magnitude_moving_decay=magnitude_moving_decay,
                             update_ema=update_ema,
                             impl=impl)
            results[f'style{idx}'] = style

        if self.output_scale != 1:
            x = x * self.output_scale
        x = x.to(torch.float32)
        if self.final_tanh:
            x = torch.tanh(x)
        results['image'] = x
        return results

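
# Editorial sketch (not part of the original file): the layer schedule built
# in `SynthesisNetwork.__init__` above, evaluated standalone for the default
# 1024x1024 configuration (num_layers=14, num_critical=2, first_cutoff=2,
# first_stopband=2 ** 2.1, last_stopband_rel=2 ** 0.3).
def _layer_schedule(resolution=1024, num_layers=14, num_critical=2,
                    first_cutoff=2, first_stopband=2 ** 2.1,
                    last_stopband_rel=2 ** 0.3):
    last_cutoff = resolution / 2
    last_stopband = last_cutoff * last_stopband_rel
    exponents = np.minimum(
        np.arange(num_layers + 1) / (num_layers - num_critical), 1)
    cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents
    stopbands = first_stopband * (last_stopband / first_stopband) ** exponents
    sampling_rates = np.exp2(np.ceil(np.log2(
        np.minimum(stopbands * 2, resolution))))
    # Cutoffs grow geometrically from 2 to resolution / 2; for the top layers
    # cutoff == sampling_rate / 2, i.e., critical sampling.
    return cutoffs, stopbands, sampling_rates
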
class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer."""

    def __init__(self, dim, eps):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def extra_repr(self):
        return f'dim={self.dim}, epsilon={self.eps}'

    def forward(self, x):
        scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
        return x * scale

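
# In short, `PixelNormLayer` computes x / sqrt(mean(x ** 2, dim) + eps), e.g.,
# for a latent batch z of shape [N, 512] with dim=1:
#     z * (z.square().mean(dim=1, keepdim=True) + 1e-8).rsqrt()
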
class InputLayer(nn.Module):
    """Implements the input layer with positional encoding.

    Basically, this block outputs a feature map with shape
    `(channels, size, size)` based on the coordinate information.
    `sampling_rate` and `cutoff` are used to control the coordinate range and
    strength respectively.

    For a low-pass filter, `cutoff` is the same as the `bandwidth`.
    The initial frequency of the starting feature map is controlled by the
    positional encoding `sin(2 * pi * x)`, where
    `x = trans(coord) * frequency + phase`. We would like to introduce rich
    information (i.e. frequencies), but keep all frequencies lower than the
    stopband, which is `sampling_rate / 2`.

    Besides, this layer also supports learning a transformation from the
    latent code w, and providing a customized transformation for inference.
    Please use the buffer `transform`.

    NOTE: `size` is different from `sampling_rate`. `sampling_rate` is the
    actual size of the current stage, which determines the maximum frequency
    that the feature maps can hold. `size` is the actual height and width of
    the current feature map, including the extended border.
    """

    def __init__(self, w_dim, channels, size, sampling_rate, cutoff):
        super().__init__()

        self.w_dim = w_dim
        self.channels = channels
        self.size = size
        self.sampling_rate = sampling_rate
        self.cutoff = cutoff

        # Coordinate of the entire feature map, with resolution (size, size).
        # The coordinate range for the central (sampling_rate, sampling_rate)
        # region is set as (-0.5, 0.5), which extends to the remaining region.
        theta = torch.eye(2, 3)
        theta[0, 0] = 0.5 / sampling_rate * size
        theta[1, 1] = 0.5 / sampling_rate * size
        grid = F.affine_grid(theta=theta.unsqueeze(0),
                             size=(1, 1, size, size),
                             align_corners=False)
        self.register_buffer('grid', grid)

        # Draw a random frequency from a uniform 2D disc for each channel
        # regarding the X and Y dimensions, and also draw a random phase for
        # each channel. Accordingly, each channel has three pre-defined
        # parameters, which are X-frequency, Y-frequency, and phase.
        frequency = torch.randn(channels, 2)
        radius = frequency.square().sum(dim=1, keepdim=True).sqrt()
        frequency = frequency / (radius * radius.square().exp().pow(0.25))
        frequency = frequency * cutoff
        self.register_buffer('frequency', frequency)
        phase = torch.rand(channels) - 0.5
        self.register_buffer('phase', phase)

        # This layer is used to map the latent code w to transform factors,
        # with order: cos(angle), sin(angle), translate_x, translate_y.
        self.affine = DenseLayer(in_channels=w_dim,
                                 out_channels=4,
                                 init_weight_std=0.0,
                                 add_bias=True,
                                 init_bias=(1, 0, 0, 0),
                                 lr_mul=1.0,
                                 activation_type='linear')

        # It is possible to use this buffer to customize the transform of the
        # output synthesis.
        self.register_buffer('transform', torch.eye(3))

        # Use 1x1 conv to convert positional encoding to features.
        self.weight = nn.Parameter(torch.randn(channels, channels))
        self.weight_scale = 1 / np.sqrt(channels)

    def extra_repr(self):
        return (f'channels={self.channels}, '
                f'size={self.size}, '
                f'sampling_rate={self.sampling_rate}, '
                f'cutoff={self.cutoff:.3f}')

    def forward(self, w):
        batch = w.shape[0]

        # Get transformation matrix.
        # Factor controlled by the latent code.
        transformation_factor = self.affine(w)
        # Ensure the range of the cosine and sine values (first two
        # dimensions).
        _norm = transformation_factor[:, :2].norm(dim=1, keepdim=True)
        transformation_factor = transformation_factor / _norm
        # Rotation.
        rotation = torch.eye(3, device=w.device).unsqueeze(0)
        rotation = rotation.repeat((batch, 1, 1))
        rotation[:, 0, 0] = transformation_factor[:, 0]
        rotation[:, 0, 1] = -transformation_factor[:, 1]
        rotation[:, 1, 0] = transformation_factor[:, 1]
        rotation[:, 1, 1] = transformation_factor[:, 0]
        # Translation.
        translation = torch.eye(3, device=w.device).unsqueeze(0)
        translation = translation.repeat((batch, 1, 1))
        translation[:, 0, 2] = -transformation_factor[:, 2]
        translation[:, 1, 2] = -transformation_factor[:, 3]
        # Customized transformation.
        transform = rotation @ translation @ self.transform.unsqueeze(0)

        # Transform frequency and shift, which is equivalent to transforming
        # the coordinate. For example, given a coordinate, X, we would like to
        # first transform it with the rotation matrix, R, and the translation
        # matrix, T, as X' = RX + T. Then, we will apply frequency, f, and
        # phase, p, with sin(2 * pi * (fX' + p)). Naturally, we have
        # fX' + p = f(RX + T) + p = (fR)X + (fT + p)
        frequency = self.frequency.unsqueeze(0) @ transform[:, :2, :2]  # [NC2]
        phase = self.frequency.unsqueeze(0) @ transform[:, :2, 2:]  # [NC1]
        phase = phase.squeeze(2) + self.phase.unsqueeze(0)  # [NC]

        # Positional encoding.
        x = self.grid  # [NHW2]
        x = x.unsqueeze(3)  # [NHW12]
        x = x @ frequency.transpose(1, 2).unsqueeze(1).unsqueeze(2)  # [NHW1C]
        x = x.squeeze(3)  # [NHWC]
        x = x + phase.unsqueeze(1).unsqueeze(2)  # [NHWC]
        x = torch.sin(2 * np.pi * x)  # [NHWC]

        # Dampen out-of-band frequencies that may be introduced by the
        # customized transform `self.transform`.
        frequency_norm = frequency.norm(dim=2)
        stopband = self.sampling_rate / 2
        factor = (frequency_norm - self.cutoff) / (stopband - self.cutoff)
        amplitude = (1 - factor).clamp(0, 1)  # [NC]
        x = x * amplitude.unsqueeze(1).unsqueeze(2)  # [NHWC]

        # Project positional encoding to features.
        weight = self.weight * self.weight_scale
        x = x @ weight.t()

        return x.permute(0, 3, 1, 2).contiguous()

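
# Editorial sketch (not part of the original file): the core Fourier-feature
# encoding of `InputLayer` above, standalone with hypothetical sizes and the
# per-channel frequency sampling simplified (unit directions scaled to
# `cutoff` instead of the disc-based distribution used above).
def _fourier_features(channels=4, size=36, sampling_rate=16, cutoff=2.0):
    theta = torch.eye(2, 3)
    theta[0, 0] = theta[1, 1] = 0.5 / sampling_rate * size
    grid = F.affine_grid(theta.unsqueeze(0), (1, 1, size, size),
                         align_corners=False)  # [1, H, W, 2]
    frequency = torch.randn(channels, 2)
    frequency = frequency / frequency.norm(dim=1, keepdim=True) * cutoff
    phase = torch.rand(channels) - 0.5
    x = grid @ frequency.t() + phase  # [1, H, W, C]
    return torch.sin(2 * np.pi * x)   # Band-limited sinusoids.
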
class SynthesisLayer(nn.Module):
|
800 |
+
"""Implements the synthesis layer.
|
801 |
+
|
802 |
+
Each synthesis layer (including ToRGB layer) consists of a
|
803 |
+
`ModulateConvLayer` and a `FilteringActLayer`. Besides, this layer will
|
804 |
+
trace the magnitude (norm) of the input feature map, and update the
|
805 |
+
statistic with `magnitude_moving_decay`.
|
806 |
+
"""
|
807 |
+
|
808 |
+
def __init__(self,
|
809 |
+
# Settings for modulated convolution.
|
810 |
+
in_channels,
|
811 |
+
out_channels,
|
812 |
+
w_dim,
|
813 |
+
kernel_size,
|
814 |
+
demodulate,
|
815 |
+
eps,
|
816 |
+
conv_clamp,
|
817 |
+
# Settings for filtering activation.
|
818 |
+
in_size,
|
819 |
+
out_size,
|
820 |
+
in_sampling_rate,
|
821 |
+
out_sampling_rate,
|
822 |
+
in_cutoff,
|
823 |
+
out_cutoff,
|
824 |
+
in_half_width,
|
825 |
+
                 out_half_width,
                 filter_size,
                 use_radial_filter,
                 act_upsampling):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            w_dim: Dimension of W space for style modulation.
            kernel_size: Size of the convolutional kernels.
            demodulate: Whether to perform style demodulation.
            eps: A small value to avoid divide overflow.
            conv_clamp: A threshold to clamp the output of convolution layers
                to avoid overflow under FP16 training.
            in_size: Size of the input feature map, i.e., height and width.
            out_size: Size of the output feature map, i.e., height and width.
            in_sampling_rate: Sampling rate of the input feature map. Different
                from `in_size` that includes extended border, this field
                controls the actual maximum frequency that can be represented
                by the feature map.
            out_sampling_rate: Sampling rate of the output feature map.
            in_cutoff: Cutoff frequency of the input feature map.
            out_cutoff: Cutoff frequency of the output feature map.
            in_half_width: Half-width of the transition band of the input
                feature map.
            out_half_width: Half-width of the transition band of the output
                feature map.
            filter_size: Size of the filter used in this layer.
            use_radial_filter: Whether to use radial filter.
            act_upsampling: Upsampling factor used before the activation.
                `1` means do not wrap upsampling and downsampling around the
                activation.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.w_dim = w_dim
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.eps = eps
        self.conv_clamp = conv_clamp

        self.in_size = in_size
        self.out_size = out_size
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width
        self.filter_size = filter_size
        self.use_radial_filter = use_radial_filter
        self.act_upsampling = act_upsampling

        self.conv = ModulateConvLayer(in_channels=in_channels,
                                      out_channels=out_channels,
                                      w_dim=w_dim,
                                      kernel_size=kernel_size,
                                      demodulate=demodulate,
                                      eps=eps)
        self.register_buffer('magnitude_ema', torch.ones(()))
        self.filter = FilteringActLayer(out_channels=out_channels,
                                        in_size=in_size,
                                        out_size=out_size,
                                        in_sampling_rate=in_sampling_rate,
                                        out_sampling_rate=out_sampling_rate,
                                        in_cutoff=in_cutoff,
                                        out_cutoff=out_cutoff,
                                        in_half_width=in_half_width,
                                        out_half_width=out_half_width,
                                        filter_size=filter_size,
                                        use_radial_filter=use_radial_filter,
                                        conv_padding=self.conv.padding,
                                        act_upsampling=act_upsampling)

    def extra_repr(self):
        return f'conv_clamp={self.conv_clamp}'

    def forward(self,
                x,
                w,
                magnitude_moving_decay=0.999,
                update_ema=False,
                impl='cuda'):
        if self.training and update_ema and magnitude_moving_decay is not None:
            magnitude = x.detach().to(torch.float32).square().mean()
            self.magnitude_ema.copy_(
                magnitude.lerp(self.magnitude_ema, magnitude_moving_decay))

        input_gain = self.magnitude_ema.rsqrt()
        x, style = self.conv(x, w, gain=input_gain, impl=impl)
        if self.act_upsampling > 1:
            x = self.filter(x, np.sqrt(2), 0.2, self.conv_clamp, impl=impl)
        else:
            x = self.filter(x, 1, 1, self.conv_clamp, impl=impl)

        return x, style
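

# NOTE: Editor's sketch, not part of the original file. The layer above keeps
# a running estimate of the input magnitude (`magnitude_ema`) and multiplies
# the convolution input by its inverse square root, keeping activations at
# roughly unit variance (important under FP16). The update rule, reproduced
# on plain tensors with illustrative shapes:
#
#     ema = torch.ones(())
#     x = torch.randn(4, 512, 16, 16)
#     magnitude = x.detach().to(torch.float32).square().mean()
#     ema.copy_(magnitude.lerp(ema, 0.999))  # ema <- 0.999*ema + 0.001*mag
#     input_gain = ema.rsqrt()               # gain passed into the conv
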
class ModulateConvLayer(nn.Module):
    """Implements the convolutional layer with style modulation.

    Different from the one introduced in StyleGAN2, this layer has the
    following changes:

    (1) fusing `conv` and `style modulation` into one op by default
    (2) NOT adding a noise onto the output feature map.
    (3) NOT activating the feature map, which is moved to `FilteringActLayer`.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 w_dim,
                 kernel_size,
                 demodulate,
                 eps):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            w_dim: Dimension of W space for style modulation.
            kernel_size: Size of the convolutional kernels.
            demodulate: Whether to perform style demodulation.
            eps: A small value to avoid divide overflow.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.w_dim = w_dim
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.eps = eps

        self.space_of_latent = 'W'

        # Set up weight.
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        self.wscale = 1.0 / np.sqrt(kernel_size * kernel_size * in_channels)
        self.padding = kernel_size - 1

        # Set up style.
        self.style = DenseLayer(in_channels=w_dim,
                                out_channels=in_channels,
                                init_weight_std=1.0,
                                add_bias=True,
                                init_bias=1.0,
                                lr_mul=1.0,
                                activation_type='linear')

    def extra_repr(self):
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'ksize={self.kernel_size}, '
                f'demodulate={self.demodulate}')

    def forward_style(self, w, impl='cuda'):
        """Gets style code from the given input.

        More specifically, if the input is from W-Space, it will be projected
        by an affine transformation. If it is from the Style Space (Y-Space),
        no operation is required.

        NOTE: For codes from Y-Space, we use slicing to make sure the dimension
        is correct, in case that the code is padded before fed into this layer.
        """
        space_of_latent = self.space_of_latent.upper()
        if space_of_latent == 'W':
            if w.ndim != 2 or w.shape[1] != self.w_dim:
                raise ValueError(f'The input tensor should be with shape '
                                 f'[batch_size, w_dim], where '
                                 f'`w_dim` equals to {self.w_dim}!\n'
                                 f'But `{w.shape}` is received!')
            style = self.style(w, impl=impl)
        elif space_of_latent == 'Y':
            if w.ndim != 2 or w.shape[1] < self.in_channels:
                raise ValueError(f'The input tensor should be with shape '
                                 f'[batch_size, y_dim], where '
                                 f'`y_dim` equals to {self.in_channels}!\n'
                                 f'But `{w.shape}` is received!')
            style = w[:, :self.in_channels]
        else:
            raise NotImplementedError(f'Not implemented `space_of_latent`: '
                                      f'`{space_of_latent}`!')
        return style

    def forward(self, x, w, gain=None, impl='cuda'):
        dtype = x.dtype
        N, C, H, W = x.shape

        # Affine on `w`.
        style = self.forward_style(w, impl=impl)
        if not self.demodulate:
            _style = style * self.wscale  # Equivalent to scaling weight.
        else:
            _style = style

        weight = self.weight
        out_ch, in_ch, kh, kw = weight.shape
        assert in_ch == C

        # Pre-normalize inputs.
        if self.demodulate:
            weight = (weight *
                      weight.square().mean(dim=(1, 2, 3), keepdim=True).rsqrt())
            _style = _style * _style.square().mean().rsqrt()

        weight = weight.unsqueeze(0)
        weight = weight * _style.reshape(N, 1, in_ch, 1, 1)  # modulation
        if self.demodulate:
            decoef = (weight.square().sum(dim=(2, 3, 4)) + self.eps).rsqrt()
            weight = weight * decoef.reshape(N, out_ch, 1, 1, 1)  # demodulation

        if gain is not None:
            gain = gain.expand(N, in_ch)
            weight = weight * gain.reshape(N, 1, in_ch, 1, 1)

        # Fuse `conv` and `style modulation` as one op, using group convolution.
        x = x.reshape(1, N * in_ch, H, W)
        w = weight.reshape(N * out_ch, in_ch, kh, kw).to(dtype)
        x = conv2d_gradfix.conv2d(
            x, w, padding=self.padding, groups=N, impl=impl)
        x = x.reshape(N, out_ch, x.shape[2], x.shape[3])

        assert x.dtype == dtype
        assert style.dtype == torch.float32
        return x, style
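

# NOTE: Editor's sketch, not part of the original file. The grouped-conv
# fusion in `ModulateConvLayer.forward` folds per-sample style modulation and
# demodulation into the kernel itself. A minimal standalone reproduction
# (illustrative shapes; `F` is `torch.nn.functional`):
#
#     N, C_in, C_out, k = 2, 8, 16, 3
#     x = torch.randn(N, C_in, 9, 9)
#     weight = torch.randn(C_out, C_in, k, k)
#     style = torch.randn(N, C_in)
#     w = weight.unsqueeze(0) * style.reshape(N, 1, C_in, 1, 1)   # modulate
#     d = (w.square().sum(dim=(2, 3, 4)) + 1e-8).rsqrt()
#     w = w * d.reshape(N, C_out, 1, 1, 1)                        # demodulate
#     y = F.conv2d(x.reshape(1, N * C_in, 9, 9),                  # group conv
#                  w.reshape(N * C_out, C_in, k, k),
#                  padding=k - 1, groups=N)
#     y = y.reshape(N, C_out, y.shape[2], y.shape[3])
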
class FilteringActLayer(nn.Module):
    """Implements the activation, wrapped with upsampling and downsampling.

    Basically, this layer executes the following operations in order:

    (1) Apply bias.
    (2) Upsample the feature map to increase sampling rate.
    (3) Apply non-linearity as activation.
    (4) Downsample the feature map to target size.

    This layer is mostly borrowed from the official implementation:

    https://github.com/NVlabs/stylegan3/blob/main/training/networks_stylegan3.py
    """

    def __init__(self,
                 out_channels,
                 in_size,
                 out_size,
                 in_sampling_rate,
                 out_sampling_rate,
                 in_cutoff,
                 out_cutoff,
                 in_half_width,
                 out_half_width,
                 filter_size,
                 use_radial_filter,
                 conv_padding,
                 act_upsampling):
        """Initializes with layer settings.

        Args:
            out_channels: Number of output channels, which is used for `bias`.
            in_size: Size of the input feature map, i.e., height and width.
            out_size: Size of the output feature map, i.e., height and width.
            in_sampling_rate: Sampling rate of the input feature map. Different
                from `in_size` that includes extended border, this field
                controls the actual maximum frequency that can be represented
                by the feature map.
            out_sampling_rate: Sampling rate of the output feature map.
            in_cutoff: Cutoff frequency of the input feature map.
            out_cutoff: Cutoff frequency of the output feature map.
            in_half_width: Half-width of the transition band of the input
                feature map.
            out_half_width: Half-width of the transition band of the output
                feature map.
            filter_size: Size of the filter used in this layer.
            use_radial_filter: Whether to use radial filter.
            conv_padding: The padding used in the previous convolutional layer.
            act_upsampling: Upsampling factor used before the activation.
                `1` means do not wrap upsampling and downsampling around the
                activation.
        """
        super().__init__()

        self.out_channels = out_channels
        self.in_size = in_size
        self.out_size = out_size
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width
        self.filter_size = filter_size
        self.use_radial_filter = use_radial_filter
        self.conv_padding = conv_padding
        self.act_upsampling = act_upsampling

        # Define bias.
        self.bias = nn.Parameter(torch.zeros(out_channels))

        # This sampling rate describes the upsampled feature map before
        # activation.
        temp_sampling_rate = max(in_sampling_rate, out_sampling_rate)
        temp_sampling_rate = temp_sampling_rate * act_upsampling

        # Design upsampling filter.
        up_factor = int(np.rint(temp_sampling_rate / in_sampling_rate))
        assert in_sampling_rate * up_factor == temp_sampling_rate
        if up_factor > 1:
            self.up_factor = up_factor
            self.up_taps = filter_size * up_factor
        else:
            self.up_factor = 1
            self.up_taps = 1  # No filtering.
        self.register_buffer(
            'up_filter',
            self.design_lowpass_filter(numtaps=self.up_taps,
                                       cutoff=in_cutoff,
                                       width=in_half_width * 2,
                                       fs=temp_sampling_rate,
                                       radial=False))

        # Design downsampling filter.
        down_factor = int(np.rint(temp_sampling_rate / out_sampling_rate))
        assert out_sampling_rate * down_factor == temp_sampling_rate
        if down_factor > 1:
            self.down_factor = down_factor
            self.down_taps = filter_size * down_factor
        else:
            self.down_factor = 1
            self.down_taps = 1  # No filtering.
        self.register_buffer(
            'down_filter',
            self.design_lowpass_filter(numtaps=self.down_taps,
                                       cutoff=out_cutoff,
                                       width=out_half_width * 2,
                                       fs=temp_sampling_rate,
                                       radial=use_radial_filter))

        # Compute padding.
        # Desired output size before downsampling.
        pad_total = (out_size - 1) * self.down_factor + 1
        # Input size after upsampling.
        pad_total = pad_total - (in_size + conv_padding) * self.up_factor
        # Size reduction caused by the filters.
        pad_total = pad_total + self.up_taps + self.down_taps - 2
        # Shift sample locations according to the symmetric interpretation.
        pad_lo = (pad_total + self.up_factor) // 2
        pad_hi = pad_total - pad_lo
        self.padding = list(map(int, (pad_lo, pad_hi, pad_lo, pad_hi)))

    def extra_repr(self):
        return (f'in_size={self.in_size}, '
                f'out_size={self.out_size}, '
                f'in_srate={self.in_sampling_rate}, '
                f'out_srate={self.out_sampling_rate}, '
                f'in_cutoff={self.in_cutoff:.3f}, '
                f'out_cutoff={self.out_cutoff:.3f}, '
                f'in_half_width={self.in_half_width:.3f}, '
                f'out_half_width={self.out_half_width:.3f}, '
                f'up_factor={self.up_factor}, '
                f'up_taps={self.up_taps}, '
                f'down_factor={self.down_factor}, '
                f'down_taps={self.down_taps}, '
                f'filter_size={self.filter_size}, '
                f'radial_filter={self.use_radial_filter}, '
                f'conv_padding={self.conv_padding}, '
                f'act_upsampling={self.act_upsampling}')

    @staticmethod
    def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
        """Designs a low-pass filter.

        Args:
            numtaps: Length of the filter (number of coefficients, i.e., the
                filter order + 1).
            cutoff: Cutoff frequency of the output filter.
            width: Width of the transition region.
            fs: Sampling frequency.
            radial: Whether to use radially symmetric jinc-based filter.
                (default: False)
        """
        if numtaps == 1:
            return None

        assert numtaps > 1

        if not radial:  # Separable Kaiser low-pass filter.
            f = scipy.signal.firwin(numtaps=numtaps,
                                    cutoff=cutoff,
                                    width=width,
                                    fs=fs)
        else:  # Radially symmetric jinc-based filter.
            x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
            r = np.hypot(*np.meshgrid(x, x))
            f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
            beta = scipy.signal.kaiser_beta(
                scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
            w = np.kaiser(numtaps, beta)
            f = f * np.outer(w, w)
        f = f / np.sum(f)
        return torch.as_tensor(f, dtype=torch.float32)

    def forward(self, x, gain, slope, clamp, impl='cuda'):
        dtype = x.dtype

        x = filtered_lrelu.filtered_lrelu(x=x,
                                          fu=self.up_filter,
                                          fd=self.down_filter,
                                          b=self.bias.to(dtype),
                                          up=self.up_factor,
                                          down=self.down_factor,
                                          padding=self.padding,
                                          gain=gain,
                                          slope=slope,
                                          clamp=clamp,
                                          impl=impl)

        assert x.dtype == dtype
        return x
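

# NOTE: Editor's sketch, not part of the original file. The filter design in
# `design_lowpass_filter` uses standard SciPy routines. The separable
# (non-radial) branch, run standalone with made-up sampling parameters:
#
#     import scipy.signal
#     taps = scipy.signal.firwin(numtaps=12,  # filter_size * up_factor
#                                cutoff=8.0,  # cutoff frequency
#                                width=4.0,   # 2 * half_width of transition
#                                fs=32.0)     # temporary sampling rate
#     taps = taps / taps.sum()  # normalize to unit DC gain, as above
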
class DenseLayer(nn.Module):
    """Implements the dense layer."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 init_weight_std,
                 add_bias,
                 init_bias,
                 lr_mul,
                 activation_type):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            init_weight_std: The initial standard deviation of weight.
            add_bias: Whether to add bias onto the fully-connected result.
            init_bias: The initial bias value before training.
            lr_mul: Learning multiplier for both weight and bias.
            activation_type: Type of activation.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.init_weight_std = init_weight_std
        self.add_bias = add_bias
        self.init_bias = init_bias
        self.lr_mul = lr_mul
        self.activation_type = activation_type

        weight_shape = (out_channels, in_channels)
        self.weight = nn.Parameter(
            torch.randn(*weight_shape) * init_weight_std / lr_mul)
        self.wscale = lr_mul / np.sqrt(in_channels)

        if add_bias:
            init_bias = np.float32(np.float32(init_bias) / lr_mul)
            if isinstance(init_bias, np.float32):
                self.bias = nn.Parameter(torch.full([out_channels], init_bias))
            else:
                assert isinstance(init_bias, np.ndarray)
                self.bias = nn.Parameter(torch.from_numpy(init_bias))
            self.bscale = lr_mul
        else:
            self.bias = None

    def extra_repr(self):
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'init_weight_std={self.init_weight_std}, '
                f'bias={self.add_bias}, '
                f'init_bias={self.init_bias}, '
                f'lr_mul={self.lr_mul:.3f}, '
                f'act={self.activation_type}')

    def forward(self, x, impl='cuda'):
        dtype = x.dtype

        if x.ndim != 2:
            x = x.flatten(start_dim=1)

        weight = self.weight.to(dtype) * self.wscale
        bias = None
        if self.bias is not None:
            bias = self.bias.to(dtype)
            if self.bscale != 1.0:
                bias = bias * self.bscale

        # Fast pass for linear activation.
        if self.activation_type == 'linear' and bias is not None:
            x = torch.addmm(bias.unsqueeze(0), x, weight.t())
        else:
            x = x.matmul(weight.t())
            x = bias_act.bias_act(x, bias, act=self.activation_type, impl=impl)

        assert x.dtype == dtype
        return x
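
# NOTE: Editor's sketch, not part of the original file. The `linear` fast path
# in `DenseLayer.forward` relies on `torch.addmm(b, x, W.t())` computing
# `x @ W.t() + b` in one fused op; any other activation goes through
# `bias_act` instead. A quick equivalence check with illustrative shapes:
#
#     x = torch.randn(4, 512)
#     W = torch.randn(256, 512)
#     b = torch.randn(256)
#     assert torch.allclose(torch.addmm(b.unsqueeze(0), x, W.t()),
#                           x @ W.t() + b, atol=1e-4)
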
# pylint: enable=missing-function-docstring

models/stylegan_discriminator.py
ADDED
@@ -0,0 +1,624 @@
# python3.7
"""Contains the implementation of discriminator described in StyleGAN.

Paper: https://arxiv.org/pdf/1812.04948.pdf

Official TensorFlow implementation: https://github.com/NVlabs/stylegan
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

__all__ = ['StyleGANDiscriminator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Fused-scale options allowed.
_FUSED_SCALE_ALLOWED = [True, False, 'auto']

# pylint: disable=missing-function-docstring

class StyleGANDiscriminator(nn.Module):
    """Defines the discriminator network in StyleGAN.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the backbone:

    (1) resolution: The resolution of the input image. (default: -1)
    (2) init_res: Smallest resolution of the convolutional backbone.
        (default: 4)
    (3) image_channels: Number of channels of the input image. (default: 3)
    (4) fused_scale: The strategy of fusing `conv2d` and `downsample` as one
        operator. `True` means blocks from all resolutions will fuse. `False`
        means blocks from all resolutions will not fuse. `auto` means blocks
        from resolutions higher than (or equal to) `fused_scale_res` will
        fuse. (default: `auto`)
    (5) fused_scale_res: Minimum resolution to fuse `conv2d` and `downsample`
        as one operator. This field only takes effect if `fused_scale` is set
        as `auto`. (default: 128)
    (6) use_wscale: Whether to use weight scaling. (default: True)
    (7) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0))
    (8) lr_mul: Learning rate multiplier for backbone. (default: 1.0)
    (9) mbstd_groups: Group size for the minibatch standard deviation layer.
        `0` means disable. (default: 4)
    (10) mbstd_channels: Number of new channels (appended to the original
        feature map) after the minibatch standard deviation layer. (default: 1)
    (11) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (12) fmaps_max: Maximum number of feature maps in each layer.
        (default: 512)
    (13) filter_kernel: Kernel used for filtering (e.g., downsampling).
        (default: (1, 2, 1))
    (14) eps: A small value to avoid divide overflow. (default: 1e-8)

    Settings for conditional model:

    (1) label_dim: Dimension of the additional label for conditional
        generation. In one-hot conditioning case, it is equal to the number
        of classes. If set to 0, conditioning training will be disabled.
        (default: 0)

    Runtime settings:

    (1) enable_amp: Whether to enable automatic mixed precision training.
        (default: False)
    """

    def __init__(self,
                 # Settings for backbone.
                 resolution=-1,
                 init_res=4,
                 image_channels=3,
                 fused_scale='auto',
                 fused_scale_res=128,
                 use_wscale=True,
                 wscale_gain=np.sqrt(2.0),
                 lr_mul=1.0,
                 mbstd_groups=4,
                 mbstd_channels=1,
                 fmaps_base=16 << 10,
                 fmaps_max=512,
                 filter_kernel=(1, 2, 1),
                 eps=1e-8,
                 # Settings for conditional model.
                 label_dim=0):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported, or `fused_scale`
                is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        if fused_scale not in _FUSED_SCALE_ALLOWED:
            raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
                             f'Options allowed: {_FUSED_SCALE_ALLOWED}.')

        self.init_res = init_res
        self.init_res_log2 = int(np.log2(init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(resolution))
        self.image_channels = image_channels
        self.fused_scale = fused_scale
        self.fused_scale_res = fused_scale_res
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.mbstd_groups = mbstd_groups
        self.mbstd_channels = mbstd_channels
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.filter_kernel = filter_kernel
        self.eps = eps
        self.label_dim = label_dim

        # Level-of-details (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            in_channels = self.get_nf(res)
            out_channels = self.get_nf(res // 2)
            block_idx = self.final_res_log2 - res_log2

            # Input convolution layer for each resolution.
            self.add_module(
                f'input{block_idx}',
                ConvLayer(in_channels=image_channels,
                          out_channels=in_channels,
                          kernel_size=1,
                          add_bias=True,
                          scale_factor=1,
                          fused_scale=False,
                          filter_kernel=None,
                          use_wscale=use_wscale,
                          wscale_gain=wscale_gain,
                          lr_mul=lr_mul,
                          activation_type='lrelu'))
            self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
                f'FromRGB_lod{block_idx}/weight')
            self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
                f'FromRGB_lod{block_idx}/bias')

            # Convolution block for each resolution (except the last one).
            if res != self.init_res:
                # First layer (kernel 3x3) without downsampling.
                layer_name = f'layer{2 * block_idx}'
                self.add_module(
                    layer_name,
                    ConvLayer(in_channels=in_channels,
                              out_channels=in_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=1,
                              fused_scale=False,
                              filter_kernel=None,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              lr_mul=lr_mul,
                              activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Conv0/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Conv0/bias')

                # Second layer (kernel 3x3) with downsampling.
                layer_name = f'layer{2 * block_idx + 1}'
                self.add_module(
                    layer_name,
                    ConvLayer(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=2,
                              fused_scale=(res >= fused_scale_res
                                           if fused_scale == 'auto'
                                           else fused_scale),
                              filter_kernel=filter_kernel,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              lr_mul=lr_mul,
                              activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Conv1_down/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Conv1_down/bias')

            # Convolution block for last resolution.
            else:
                self.mbstd = MiniBatchSTDLayer(groups=mbstd_groups,
                                               new_channels=mbstd_channels,
                                               eps=eps)

                # First layer (kernel 3x3) without downsampling.
                layer_name = f'layer{2 * block_idx}'
                self.add_module(
                    layer_name,
                    ConvLayer(in_channels=in_channels + mbstd_channels,
                              out_channels=in_channels,
                              kernel_size=3,
                              add_bias=True,
                              scale_factor=1,
                              fused_scale=False,
                              filter_kernel=None,
                              use_wscale=use_wscale,
                              wscale_gain=wscale_gain,
                              lr_mul=lr_mul,
                              activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Conv/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Conv/bias')

                # Second layer, as a fully-connected layer.
                layer_name = f'layer{2 * block_idx + 1}'
                self.add_module(
                    layer_name,
                    DenseLayer(in_channels=in_channels * res * res,
                               out_channels=in_channels,
                               add_bias=True,
                               use_wscale=use_wscale,
                               wscale_gain=wscale_gain,
                               lr_mul=lr_mul,
                               activation_type='lrelu'))
                self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
                    f'{res}x{res}/Dense0/weight')
                self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
                    f'{res}x{res}/Dense0/bias')

                # Final dense layer to output score.
                self.output = DenseLayer(in_channels=in_channels,
                                         out_channels=max(label_dim, 1),
                                         add_bias=True,
                                         use_wscale=use_wscale,
                                         wscale_gain=1.0,
                                         lr_mul=lr_mul,
                                         activation_type='linear')
                self.pth_to_tf_var_mapping['output.weight'] = (
                    f'{res}x{res}/Dense1/weight')
                self.pth_to_tf_var_mapping['output.bias'] = (
                    f'{res}x{res}/Dense1/bias')

    def get_nf(self, res):
        """Gets number of feature maps according to the given resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, image, label=None, lod=None, enable_amp=False):
        expected_shape = (self.image_channels, self.resolution, self.resolution)
        if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals to {self.image_channels}, '
                             f'`height`, `width` equal to {self.resolution}!\n'
                             f'But `{image.shape}` is received!')

        lod = self.lod.item() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-details (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        if self.label_dim:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with dimension {self.label_dim}) as input, '
                                 f'but no label is received!')
            batch = image.shape[0]
            if label.ndim != 2 or label.shape != (batch, self.label_dim):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_dim], where '
                                 f'`batch_size` equals to {batch}, and '
                                 f'`label_dim` equals to {self.label_dim}!\n'
                                 f'But `{label.shape}` is received!')
            label = label.to(dtype=torch.float32)

        with autocast(enabled=enable_amp):
            for res_log2 in range(
                    self.final_res_log2, self.init_res_log2 - 1, -1):
                block_idx = current_lod = self.final_res_log2 - res_log2
                if current_lod <= lod < current_lod + 1:
                    x = getattr(self, f'input{block_idx}')(image)
                elif current_lod - 1 < lod < current_lod:
                    alpha = lod - np.floor(lod)
                    y = getattr(self, f'input{block_idx}')(image)
                    x = y * alpha + x * (1 - alpha)
                if lod < current_lod + 1:
                    if res_log2 == self.init_res_log2:
                        x = self.mbstd(x)
                    x = getattr(self, f'layer{2 * block_idx}')(x)
                    x = getattr(self, f'layer{2 * block_idx + 1}')(x)
                if lod > current_lod:
                    image = F.avg_pool2d(
                        image, kernel_size=2, stride=2, padding=0)
            x = self.output(x)

            if self.label_dim:
                x = (x * label).sum(dim=1, keepdim=True)

        results = {
            'score': x,
            'label': label
        }
        return results
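

# NOTE: Editor's sketch, not part of the original file. In the progressive
# `forward` above, the integer part of `lod` picks the entry block and the
# fractional part cross-fades two adjacent input pathways. For example, with
# a hypothetical lod of 1.7 on a 256x256 model (final_res_log2 = 8):
#
#     lod = 1.7
#     alpha = lod - np.floor(lod)  # 0.7
#     # Block 1 (128x128) consumes the image first; at block 2 (64x64) the
#     # avg-pooled image is re-injected through `input2` and blended:
#     #     x = input2(image_64) * alpha + x * (1 - alpha)
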
class MiniBatchSTDLayer(nn.Module):
    """Implements the minibatch standard deviation layer."""

    def __init__(self, groups, new_channels, eps):
        super().__init__()
        self.groups = groups
        self.new_channels = new_channels
        self.eps = eps

    def extra_repr(self):
        return (f'groups={self.groups}, '
                f'new_channels={self.new_channels}, '
                f'epsilon={self.eps}')

    def forward(self, x):
        if self.groups <= 1 or self.new_channels < 1:
            return x

        N, C, H, W = x.shape
        G = min(self.groups, N)  # Number of groups.
        nC = self.new_channels   # Number of channel groups.
        c = C // nC              # Channels per channel group.

        y = x.reshape(G, -1, nC, c, H, W)  # [GnFcHW]
        y = y - y.mean(dim=0)              # [GnFcHW]
        y = y.square().mean(dim=0)         # [nFcHW]
        y = (y + self.eps).sqrt()          # [nFcHW]
        y = y.mean(dim=(2, 3, 4))          # [nF]
        y = y.reshape(-1, nC, 1, 1)        # [nF11]
        y = y.repeat(G, 1, H, W)           # [NFHW]
        x = torch.cat((x, y), dim=1)       # [N(C+F)HW]

        return x
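

# NOTE: Editor's sketch, not part of the original file. The layer above
# appends batch-statistics channels so the discriminator can penalize
# low-variety batches. A minimal reproduction with one statistics channel
# and a single group (illustrative shapes):
#
#     x = torch.randn(8, 16, 4, 4)                # [N, C, H, W]
#     y = x - x.mean(dim=0)                       # center over the batch
#     y = (y.square().mean(dim=0) + 1e-8).sqrt()  # std per feature
#     y = y.mean().repeat(8, 1, 4, 4)             # scalar -> [N, 1, H, W]
#     out = torch.cat((x, y), dim=1)              # [8, 17, 4, 4]
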
class Blur(torch.autograd.Function):
    """Defines blur operation with customized gradient computation."""

    @staticmethod
    def forward(ctx, x, kernel):
        assert kernel.shape[2] == 3 and kernel.shape[3] == 3
        ctx.save_for_backward(kernel)
        y = F.conv2d(input=x,
                     weight=kernel,
                     bias=None,
                     stride=1,
                     padding=1,
                     groups=x.shape[1])
        return y

    @staticmethod
    def backward(ctx, dy):
        kernel, = ctx.saved_tensors
        dx = BlurBackPropagation.apply(dy, kernel)
        return dx, None, None


class BlurBackPropagation(torch.autograd.Function):
    """Defines the back propagation of blur operation.

    NOTE: This is used to speed up the backward of gradient penalty.
    """

    @staticmethod
    def forward(ctx, dy, kernel):
        ctx.save_for_backward(kernel)
        dx = F.conv2d(input=dy,
                      weight=kernel.flip((2, 3)),
                      bias=None,
                      stride=1,
                      padding=1,
                      groups=dy.shape[1])
        return dx

    @staticmethod
    def backward(ctx, ddx):
        kernel, = ctx.saved_tensors
        ddy = F.conv2d(input=ddx,
                       weight=kernel,
                       bias=None,
                       stride=1,
                       padding=1,
                       groups=ddx.shape[1])
        return ddy, None, None
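

# NOTE: Editor's sketch, not part of the original file. The pair of autograd
# functions above uses the identity that the gradient of a correlation with a
# fixed kernel is a correlation with the spatially flipped kernel, so the
# double-backward needed by gradient penalties stays a single cheap conv.
# A sanity check under these assumptions:
#
#     k = torch.rand(1, 1, 3, 3, dtype=torch.float64)
#     x = torch.randn(1, 1, 8, 8, dtype=torch.float64, requires_grad=True)
#     torch.autograd.gradcheck(Blur.apply, (x, k))
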
class ConvLayer(nn.Module):
    """Implements the convolutional layer.

    If downsampling is needed (i.e., `scale_factor = 2`), the feature map will
    be filtered with `filter_kernel` first. If `fused_scale` is set as `True`,
    `conv2d` and `downsample` will be fused as one operator, using stride
    convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 add_bias,
                 scale_factor,
                 fused_scale,
                 filter_kernel,
                 use_wscale,
                 wscale_gain,
                 lr_mul,
                 activation_type):
        """Initializes with layer settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels.
            add_bias: Whether to add bias onto the convolutional result.
            scale_factor: Scale factor for downsampling. `1` means skip
                downsampling.
            fused_scale: Whether to fuse `conv2d` and `downsample` as one
                operator, using stride convolution.
            filter_kernel: Kernel used for filtering.
            use_wscale: Whether to use weight scaling.
            wscale_gain: Gain factor for weight scaling.
            lr_mul: Learning multiplier for both weight and bias.
            activation_type: Type of activation.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.add_bias = add_bias
        self.scale_factor = scale_factor
        self.fused_scale = fused_scale
        self.filter_kernel = filter_kernel
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.activation_type = activation_type

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(*weight_shape) * wscale / lr_mul)
            self.wscale = lr_mul

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None

        if scale_factor > 1:
            assert filter_kernel is not None
            kernel = np.array(filter_kernel, dtype=np.float32).reshape(1, -1)
            kernel = kernel.T.dot(kernel)
            kernel = kernel / np.sum(kernel)
            kernel = kernel[np.newaxis, np.newaxis]
            self.register_buffer('filter', torch.from_numpy(kernel))

        if scale_factor > 1 and fused_scale:  # Use stride convolution.
            self.stride = scale_factor
        else:
            self.stride = 1
        self.padding = kernel_size // 2

        assert activation_type in ['linear', 'relu', 'lrelu']

    def extra_repr(self):
        return (f'in_ch={self.in_channels}, '
                f'out_ch={self.out_channels}, '
                f'ksize={self.kernel_size}, '
                f'wscale_gain={self.wscale_gain:.3f}, '
                f'bias={self.add_bias}, '
                f'lr_mul={self.lr_mul:.3f}, '
                f'downsample={self.scale_factor}, '
                f'fused_scale={self.fused_scale}, '
                f'downsample_filter={self.filter_kernel}, '
                f'act={self.activation_type}')

    def forward(self, x):
        if self.scale_factor > 1:
            # Disable `autocast` for customized autograd function.
            # Please check reference:
            # https://pytorch.org/docs/stable/notes/amp_examples.html#autocast-and-custom-autograd-functions
            with autocast(enabled=False):
                f = self.filter.repeat(self.in_channels, 1, 1, 1)
                x = Blur.apply(x.float(), f)  # Always use FP32.

        weight = self.weight
        if self.wscale != 1.0:
            weight = weight * self.wscale
        bias = None
        if self.bias is not None:
            bias = self.bias
            if self.bscale != 1.0:
                bias = bias * self.bscale

        if self.scale_factor > 1 and self.fused_scale:
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
        x = F.conv2d(x,
                     weight=weight,
                     bias=bias,
                     stride=self.stride,
                     padding=self.padding)
        if self.scale_factor > 1 and not self.fused_scale:
            down = self.scale_factor
            x = F.avg_pool2d(x, kernel_size=down, stride=down, padding=0)

        if self.activation_type == 'linear':
            pass
        elif self.activation_type == 'relu':
            x = F.relu(x, inplace=True)
        elif self.activation_type == 'lrelu':
            x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation type '
                                      f'`{self.activation_type}`!')

        return x
|
541 |
+
"""Implements the dense layer."""
|
542 |
+
|
543 |
+
def __init__(self,
|
544 |
+
in_channels,
|
545 |
+
out_channels,
|
546 |
+
add_bias,
|
547 |
+
use_wscale,
|
548 |
+
wscale_gain,
|
549 |
+
lr_mul,
|
550 |
+
activation_type):
|
551 |
+
"""Initializes with layer settings.
|
552 |
+
|
553 |
+
Args:
|
554 |
+
in_channels: Number of channels of the input tensor.
|
555 |
+
out_channels: Number of channels of the output tensor.
|
556 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
557 |
+
use_wscale: Whether to use weight scaling.
|
558 |
+
wscale_gain: Gain factor for weight scaling.
|
559 |
+
lr_mul: Learning multiplier for both weight and bias.
|
560 |
+
activation_type: Type of activation.
|
561 |
+
"""
|
562 |
+
super().__init__()
|
563 |
+
self.in_channels = in_channels
|
564 |
+
self.out_channels = out_channels
|
565 |
+
self.add_bias = add_bias
|
566 |
+
self.use_wscale = use_wscale
|
567 |
+
self.wscale_gain = wscale_gain
|
568 |
+
self.lr_mul = lr_mul
|
569 |
+
self.activation_type = activation_type
|
570 |
+
|
571 |
+
weight_shape = (out_channels, in_channels)
|
572 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
573 |
+
if use_wscale:
|
574 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
575 |
+
self.wscale = wscale * lr_mul
|
576 |
+
else:
|
577 |
+
self.weight = nn.Parameter(
|
578 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
579 |
+
self.wscale = lr_mul
|
580 |
+
|
581 |
+
if add_bias:
|
582 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
583 |
+
self.bscale = lr_mul
|
584 |
+
else:
|
585 |
+
self.bias = None
|
586 |
+
|
587 |
+
assert activation_type in ['linear', 'relu', 'lrelu']
|
588 |
+
|
589 |
+
def extra_repr(self):
|
590 |
+
return (f'in_ch={self.in_channels}, '
|
591 |
+
f'out_ch={self.out_channels}, '
|
592 |
+
f'wscale_gain={self.wscale_gain:.3f}, '
|
593 |
+
f'bias={self.add_bias}, '
|
594 |
+
f'lr_mul={self.lr_mul:.3f}, '
|
595 |
+
f'act={self.activation_type}')
|
596 |
+
|
597 |
+
def forward(self, x):
|
598 |
+
if x.ndim != 2:
|
599 |
+
x = x.flatten(start_dim=1)
|
600 |
+
|
601 |
+
weight = self.weight
|
602 |
+
if self.wscale != 1.0:
|
603 |
+
weight = weight * self.wscale
|
604 |
+
bias = None
|
605 |
+
if self.bias is not None:
|
606 |
+
bias = self.bias
|
607 |
+
if self.bscale != 1.0:
|
608 |
+
bias = bias * self.bscale
|
609 |
+
|
610 |
+
x = F.linear(x, weight=weight, bias=bias)
|
611 |
+
|
612 |
+
if self.activation_type == 'linear':
|
613 |
+
pass
|
614 |
+
elif self.activation_type == 'relu':
|
615 |
+
x = F.relu(x, inplace=True)
|
616 |
+
elif self.activation_type == 'lrelu':
|
617 |
+
x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
|
618 |
+
else:
|
619 |
+
raise NotImplementedError(f'Not implemented activation type '
|
620 |
+
f'`{self.activation_type}`!')
|
621 |
+
|
622 |
+
return x
|
623 |
+
|
624 |
+
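
# NOTE: Editor's sketch, not part of the original file. With `use_wscale`,
# both layers above store weights at unit variance and rescale them at run
# time by He's constant `wscale_gain / sqrt(fan_in)`, which equalizes the
# effective learning rate across layers of different widths:
#
#     fan_in = 512
#     weight = torch.randn(256, fan_in)        # stored as N(0, 1)
#     wscale = np.sqrt(2.0) / np.sqrt(fan_in)  # applied inside forward()
#     effective_weight = weight * wscale
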
# pylint: enable=missing-function-docstring

models/stylegan_generator.py
ADDED
@@ -0,0 +1,999 @@
# python3.7
"""Contains the implementation of generator described in StyleGAN.

Paper: https://arxiv.org/pdf/1812.04948.pdf

Official TensorFlow implementation: https://github.com/NVlabs/stylegan
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from .utils.ops import all_gather

__all__ = ['StyleGANGenerator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Fused-scale options allowed.
_FUSED_SCALE_ALLOWED = [True, False, 'auto']

# pylint: disable=missing-function-docstring

class StyleGANGenerator(nn.Module):
    """Defines the generator network in StyleGAN.

    NOTE: The synthesized images are with `RGB` channel order and pixel range
    [-1, 1].

    Settings for the mapping network:

    (1) z_dim: Dimension of the input latent space, Z. (default: 512)
    (2) w_dim: Dimension of the output latent space, W. (default: 512)
    (3) repeat_w: Repeat w-code for different layers. (default: True)
    (4) normalize_z: Whether to normalize the z-code. (default: True)
    (5) mapping_layers: Number of layers of the mapping network. (default: 8)
    (6) mapping_fmaps: Number of hidden channels of the mapping network.
        (default: 512)
    (7) mapping_use_wscale: Whether to use weight scaling for the mapping
        network. (default: True)
    (8) mapping_wscale_gain: The factor to control weight scaling for the
        mapping network. (default: sqrt(2.0))
    (9) mapping_lr_mul: Learning rate multiplier for the mapping network.
        (default: 0.01)

    Settings for conditional generation:

    (1) label_dim: Dimension of the additional label for conditional
        generation. In one-hot conditioning case, it is equal to the number
        of classes. If set to 0, conditioning training will be disabled.
        (default: 0)
    (2) embedding_dim: Dimension of the embedding space, if needed.
        (default: 512)

    Settings for the synthesis network:

    (1) resolution: The resolution of the output image. (default: -1)
    (2) init_res: The initial resolution to start with convolution.
        (default: 4)
    (3) image_channels: Number of channels of the output image. (default: 3)
    (4) final_tanh: Whether to use `tanh` to control the final pixel range.
        (default: False)
    (5) fused_scale: The strategy of fusing `upsample` and `conv2d` as one
        operator. `True` means blocks from all resolutions will fuse. `False`
        means blocks from all resolutions will not fuse. `auto` means blocks
        from resolutions higher than (or equal to) `fused_scale_res` will
        fuse. (default: `auto`)
    (6) fused_scale_res: Minimum resolution to fuse `conv2d` and `downsample`
        as one operator. This field only takes effect if `fused_scale` is set
        as `auto`. (default: 128)
    (7) use_wscale: Whether to use weight scaling. (default: True)
    (8) wscale_gain: The factor to control weight scaling. (default: sqrt(2.0))
    (9) lr_mul: Learning rate multiplier for the synthesis network.
        (default: 1.0)
    (10) noise_type: Type of noise added to the convolutional results at each
        layer. (default: `spatial`)
    (11) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (12) fmaps_max: Maximum number of feature maps in each layer.
        (default: 512)
    (13) filter_kernel: Kernel used for filtering (e.g., downsampling).
        (default: (1, 2, 1))
    (14) eps: A small value to avoid divide overflow. (default: 1e-8)

    Runtime settings:

    (1) w_moving_decay: Decay factor for updating `w_avg`, which is used for
        training only. Set `None` to disable. (default: None)
    (2) sync_w_avg: Synchronizing the stats of `w_avg` across replicas. If set
        as `True`, the stats will be more accurate, yet the speed may be a
        little bit slower. (default: False)
    (3) style_mixing_prob: Probability to perform style mixing as a training
        regularization. Set `None` to disable. (default: None)
    (4) trunc_psi: Truncation psi, set `None` to disable. (default: None)
    (5) trunc_layers: Number of layers to perform truncation. (default: None)
    (6) noise_mode: Mode of the layer-wise noise. Support `none`, `random`,
        `const`. (default: `const`)
    (7) enable_amp: Whether to enable automatic mixed precision training.
        (default: False)
    """

    def __init__(self,
                 # Settings for mapping network.
                 z_dim=512,
                 w_dim=512,
                 repeat_w=True,
                 normalize_z=True,
                 mapping_layers=8,
                 mapping_fmaps=512,
                 mapping_use_wscale=True,
                 mapping_wscale_gain=np.sqrt(2.0),
                 mapping_lr_mul=0.01,
                 # Settings for conditional generation.
                 label_dim=0,
                 embedding_dim=512,
                 # Settings for synthesis network.
                 resolution=-1,
                 init_res=4,
                 image_channels=3,
                 final_tanh=False,
                 fused_scale='auto',
                 fused_scale_res=128,
                 use_wscale=True,
                 wscale_gain=np.sqrt(2.0),
                 lr_mul=1.0,
                 noise_type='spatial',
                 fmaps_base=16 << 10,
                 fmaps_max=512,
                 filter_kernel=(1, 2, 1),
                 eps=1e-8):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported, or `fused_scale`
                is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
        if fused_scale not in _FUSED_SCALE_ALLOWED:
            raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
                             f'Options allowed: {_FUSED_SCALE_ALLOWED}.')

        self.z_dim = z_dim
        self.w_dim = w_dim
        self.repeat_w = repeat_w
        self.normalize_z = normalize_z
        self.mapping_layers = mapping_layers
        self.mapping_fmaps = mapping_fmaps
        self.mapping_use_wscale = mapping_use_wscale
        self.mapping_wscale_gain = mapping_wscale_gain
        self.mapping_lr_mul = mapping_lr_mul

        self.label_dim = label_dim
        self.embedding_dim = embedding_dim

        self.resolution = resolution
        self.init_res = init_res
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.fused_scale = fused_scale
        self.fused_scale_res = fused_scale_res
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.noise_type = noise_type.lower()
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.filter_kernel = filter_kernel
        self.eps = eps

        # Dimension of latent space, which is convenient for sampling.
        self.latent_dim = (z_dim,)

        # Number of synthesis (convolutional) layers.
        self.num_layers = int(np.log2(resolution // init_res * 2)) * 2

        self.mapping = MappingNetwork(input_dim=z_dim,
                                      output_dim=w_dim,
                                      num_outputs=self.num_layers,
                                      repeat_output=repeat_w,
                                      normalize_input=normalize_z,
                                      num_layers=mapping_layers,
                                      hidden_dim=mapping_fmaps,
                                      use_wscale=mapping_use_wscale,
                                      wscale_gain=mapping_wscale_gain,
                                      lr_mul=mapping_lr_mul,
                                      label_dim=label_dim,
                                      embedding_dim=embedding_dim,
                                      eps=eps)

        # This is used for truncation trick.
        if self.repeat_w:
            self.register_buffer('w_avg', torch.zeros(w_dim))
        else:
            self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim))

        self.synthesis = SynthesisNetwork(resolution=resolution,
                                          init_res=init_res,
                                          w_dim=w_dim,
                                          image_channels=image_channels,
                                          final_tanh=final_tanh,
                                          fused_scale=fused_scale,
                                          fused_scale_res=fused_scale_res,
                                          use_wscale=use_wscale,
                                          wscale_gain=wscale_gain,
                                          lr_mul=lr_mul,
                                          noise_type=noise_type,
                                          fmaps_base=fmaps_base,
                                          fmaps_max=fmaps_max,
                                          filter_kernel=filter_kernel,
                                          eps=eps)

        self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
        for key, val in self.mapping.pth_to_tf_var_mapping.items():
            self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
        for key, val in self.synthesis.pth_to_tf_var_mapping.items():
            self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val

    def set_space_of_latent(self, space_of_latent):
        """Sets the space to which the latent code belongs.

        See `SynthesisNetwork` for more details.
        """
        self.synthesis.set_space_of_latent(space_of_latent)

    def forward(self,
                z,
                label=None,
                lod=None,
                w_moving_decay=None,
                sync_w_avg=False,
                style_mixing_prob=None,
                trunc_psi=None,
                trunc_layers=None,
                noise_mode='const',
                enable_amp=False):
        mapping_results = self.mapping(z, label)

        w = mapping_results['w']
        if self.training and w_moving_decay is not None:
            if sync_w_avg:
                batch_w_avg = all_gather(w.detach()).mean(dim=0)
            else:
                batch_w_avg = w.detach().mean(dim=0)
            self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay))

        wp = mapping_results.pop('wp')
        if self.training and style_mixing_prob is not None:
            if np.random.uniform() < style_mixing_prob:
                new_z = torch.randn_like(z)
                new_wp = self.mapping(new_z, label)['wp']
                lod = self.synthesis.lod.item() if lod is None else lod
                current_layers = self.num_layers - int(lod) * 2
                mixing_cutoff = np.random.randint(1, current_layers)
                wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:]

        if not self.training:
            trunc_psi = 1.0 if trunc_psi is None else trunc_psi
            trunc_layers = 0 if trunc_layers is None else trunc_layers
            if trunc_psi < 1.0 and trunc_layers > 0:
                w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers]
                wp[:, :trunc_layers] = w_avg.lerp(
                    wp[:, :trunc_layers], trunc_psi)

        with autocast(enabled=enable_amp):
            synthesis_results = self.synthesis(wp,
                                               lod=lod,
                                               noise_mode=noise_mode)

        return {**mapping_results, **synthesis_results}
class MappingNetwork(nn.Module):
|
278 |
+
"""Implements the latent space mapping module.
|
279 |
+
|
280 |
+
Basically, this module executes several dense layers in sequence, and the
|
281 |
+
label embedding if needed.
|
282 |
+
"""
|
283 |
+
|
284 |
+
def __init__(self,
|
285 |
+
input_dim,
|
286 |
+
output_dim,
|
287 |
+
num_outputs,
|
288 |
+
repeat_output,
|
289 |
+
normalize_input,
|
290 |
+
num_layers,
|
291 |
+
hidden_dim,
|
292 |
+
use_wscale,
|
293 |
+
wscale_gain,
|
294 |
+
lr_mul,
|
295 |
+
label_dim,
|
296 |
+
embedding_dim,
|
297 |
+
eps):
|
298 |
+
super().__init__()
|
299 |
+
|
300 |
+
self.input_dim = input_dim
|
301 |
+
self.output_dim = output_dim
|
302 |
+
self.num_outputs = num_outputs
|
303 |
+
self.repeat_output = repeat_output
|
304 |
+
self.normalize_input = normalize_input
|
305 |
+
self.num_layers = num_layers
|
306 |
+
self.hidden_dim = hidden_dim
|
307 |
+
self.use_wscale = use_wscale
|
308 |
+
self.wscale_gain = wscale_gain
|
309 |
+
self.lr_mul = lr_mul
|
310 |
+
self.label_dim = label_dim
|
311 |
+
self.embedding_dim = embedding_dim
|
312 |
+
self.eps = eps
|
313 |
+
|
314 |
+
self.pth_to_tf_var_mapping = {}
|
315 |
+
|
316 |
+
if normalize_input:
|
317 |
+
self.norm = PixelNormLayer(dim=1, eps=eps)
|
318 |
+
|
319 |
+
if self.label_dim > 0:
|
320 |
+
input_dim = input_dim + embedding_dim
|
321 |
+
self.embedding = nn.Parameter(
|
322 |
+
torch.randn(label_dim, embedding_dim))
|
323 |
+
self.pth_to_tf_var_mapping['embedding'] = 'LabelConcat/weight'
|
324 |
+
|
325 |
+
if num_outputs is not None and not repeat_output:
|
326 |
+
output_dim = output_dim * num_outputs
|
327 |
+
for i in range(num_layers):
|
328 |
+
in_channels = (input_dim if i == 0 else hidden_dim)
|
329 |
+
out_channels = (output_dim if i == (num_layers - 1) else hidden_dim)
|
330 |
+
self.add_module(f'dense{i}',
|
331 |
+
DenseLayer(in_channels=in_channels,
|
332 |
+
out_channels=out_channels,
|
333 |
+
add_bias=True,
|
334 |
+
use_wscale=use_wscale,
|
335 |
+
wscale_gain=wscale_gain,
|
336 |
+
lr_mul=lr_mul,
|
337 |
+
activation_type='lrelu'))
|
338 |
+
self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
|
339 |
+
self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
|
340 |
+
|
341 |
+
def forward(self, z, label=None):
|
342 |
+
if z.ndim != 2 or z.shape[1] != self.input_dim:
|
343 |
+
raise ValueError(f'Input latent code should be with shape '
|
344 |
+
f'[batch_size, input_dim], where '
|
345 |
+
f'`input_dim` equals to {self.input_dim}!\n'
|
346 |
+
f'But `{z.shape}` is received!')
|
347 |
+
|
348 |
+
if self.label_dim > 0:
|
349 |
+
if label is None:
|
350 |
+
raise ValueError(f'Model requires an additional label '
|
351 |
+
f'(with dimension {self.label_dim}) as input, '
|
352 |
+
f'but no label is received!')
|
353 |
+
if label.ndim != 2 or label.shape != (z.shape[0], self.label_dim):
|
354 |
+
raise ValueError(f'Input label should be with shape '
|
355 |
+
f'[batch_size, label_dim], where '
|
356 |
+
f'`batch_size` equals to that of '
|
357 |
+
f'latent codes ({z.shape[0]}) and '
|
358 |
+
f'`label_dim` equals to {self.label_dim}!\n'
|
359 |
+
f'But `{label.shape}` is received!')
|
360 |
+
label = label.to(dtype=torch.float32)
|
361 |
+
embedding = torch.matmul(label, self.embedding)
|
362 |
+
z = torch.cat((z, embedding), dim=1)
|
363 |
+
|
364 |
+
if self.normalize_input:
|
365 |
+
w = self.norm(z)
|
366 |
+
else:
|
367 |
+
w = z
|
368 |
+
|
369 |
+
for i in range(self.num_layers):
|
370 |
+
w = getattr(self, f'dense{i}')(w)
|
371 |
+
|
372 |
+
wp = None
|
373 |
+
if self.num_outputs is not None:
|
374 |
+
if self.repeat_output:
|
375 |
+
wp = w.unsqueeze(1).repeat((1, self.num_outputs, 1))
|
376 |
+
else:
|
377 |
+
wp = w.reshape(-1, self.num_outputs, self.output_dim)
|
378 |
+
|
379 |
+
results = {
|
380 |
+
'z': z,
|
381 |
+
'label': label,
|
382 |
+
'w': w,
|
383 |
+
'wp': wp,
|
384 |
+
}
|
385 |
+
if self.label_dim > 0:
|
386 |
+
results['embedding'] = embedding
|
387 |
+
return results
|
388 |
+
|
389 |
+
|
390 |
+
class SynthesisNetwork(nn.Module):
|
391 |
+
"""Implements the image synthesis module.
|
392 |
+
|
393 |
+
Basically, this module executes several convolutional layers in sequence.
|
394 |
+
"""
|
395 |
+
|
396 |
+
def __init__(self,
|
397 |
+
resolution,
|
398 |
+
init_res,
|
399 |
+
w_dim,
|
400 |
+
image_channels,
|
401 |
+
final_tanh,
|
402 |
+
fused_scale,
|
403 |
+
fused_scale_res,
|
404 |
+
use_wscale,
|
405 |
+
wscale_gain,
|
406 |
+
lr_mul,
|
407 |
+
noise_type,
|
408 |
+
fmaps_base,
|
409 |
+
fmaps_max,
|
410 |
+
filter_kernel,
|
411 |
+
eps):
|
412 |
+
super().__init__()
|
413 |
+
|
414 |
+
self.init_res = init_res
|
415 |
+
self.init_res_log2 = int(np.log2(init_res))
|
416 |
+
self.resolution = resolution
|
417 |
+
self.final_res_log2 = int(np.log2(resolution))
|
418 |
+
self.w_dim = w_dim
|
419 |
+
self.image_channels = image_channels
|
420 |
+
self.final_tanh = final_tanh
|
421 |
+
self.fused_scale = fused_scale
|
422 |
+
self.fused_scale_res = fused_scale_res
|
423 |
+
self.use_wscale = use_wscale
|
424 |
+
self.wscale_gain = wscale_gain
|
425 |
+
self.lr_mul = lr_mul
|
426 |
+
self.noise_type = noise_type.lower()
|
427 |
+
self.fmaps_base = fmaps_base
|
428 |
+
self.fmaps_max = fmaps_max
|
429 |
+
self.eps = eps
|
430 |
+
|
431 |
+
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
|
432 |
+
|
433 |
+
# Level-of-details (used for progressive training).
|
434 |
+
self.register_buffer('lod', torch.zeros(()))
|
435 |
+
self.pth_to_tf_var_mapping = {'lod': 'lod'}
|
436 |
+
|
437 |
+
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
|
438 |
+
res = 2 ** res_log2
|
439 |
+
in_channels = self.get_nf(res // 2)
|
440 |
+
out_channels = self.get_nf(res)
|
441 |
+
block_idx = res_log2 - self.init_res_log2
|
442 |
+
|
443 |
+
# First layer (kernel 3x3) with upsampling
|
444 |
+
layer_name = f'layer{2 * block_idx}'
|
445 |
+
if res == self.init_res:
|
446 |
+
self.add_module(layer_name,
|
447 |
+
ModulateConvLayer(in_channels=0,
|
448 |
+
out_channels=out_channels,
|
449 |
+
resolution=res,
|
450 |
+
w_dim=w_dim,
|
451 |
+
kernel_size=None,
|
452 |
+
add_bias=True,
|
453 |
+
scale_factor=None,
|
454 |
+
fused_scale=None,
|
455 |
+
filter_kernel=None,
|
456 |
+
use_wscale=use_wscale,
|
457 |
+
wscale_gain=wscale_gain,
|
458 |
+
lr_mul=lr_mul,
|
459 |
+
noise_type=noise_type,
|
460 |
+
activation_type='lrelu',
|
461 |
+
use_style=True,
|
462 |
+
eps=eps))
|
463 |
+
tf_layer_name = 'Const'
|
464 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.const'] = (
|
465 |
+
f'{res}x{res}/{tf_layer_name}/const')
|
466 |
+
else:
|
467 |
+
self.add_module(
|
468 |
+
layer_name,
|
469 |
+
ModulateConvLayer(in_channels=in_channels,
|
470 |
+
out_channels=out_channels,
|
471 |
+
resolution=res,
|
472 |
+
w_dim=w_dim,
|
473 |
+
kernel_size=3,
|
474 |
+
add_bias=True,
|
475 |
+
scale_factor=2,
|
476 |
+
fused_scale=(res >= fused_scale_res
|
477 |
+
if fused_scale == 'auto'
|
478 |
+
else fused_scale),
|
479 |
+
filter_kernel=filter_kernel,
|
480 |
+
use_wscale=use_wscale,
|
481 |
+
wscale_gain=wscale_gain,
|
482 |
+
lr_mul=lr_mul,
|
483 |
+
noise_type=noise_type,
|
484 |
+
activation_type='lrelu',
|
485 |
+
use_style=True,
|
486 |
+
eps=eps))
|
487 |
+
tf_layer_name = 'Conv0_up'
|
488 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
489 |
+
f'{res}x{res}/{tf_layer_name}/weight')
|
490 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
491 |
+
f'{res}x{res}/{tf_layer_name}/bias')
|
492 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
493 |
+
f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
|
494 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
495 |
+
f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
|
496 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
|
497 |
+
f'{res}x{res}/{tf_layer_name}/Noise/weight')
|
498 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
|
499 |
+
f'noise{2 * block_idx}')
|
500 |
+
|
501 |
+
# Second layer (kernel 3x3) without upsampling.
|
502 |
+
layer_name = f'layer{2 * block_idx + 1}'
|
503 |
+
self.add_module(layer_name,
|
504 |
+
ModulateConvLayer(in_channels=out_channels,
|
505 |
+
out_channels=out_channels,
|
506 |
+
resolution=res,
|
507 |
+
w_dim=w_dim,
|
508 |
+
kernel_size=3,
|
509 |
+
add_bias=True,
|
510 |
+
scale_factor=1,
|
511 |
+
fused_scale=False,
|
512 |
+
filter_kernel=None,
|
513 |
+
use_wscale=use_wscale,
|
514 |
+
wscale_gain=wscale_gain,
|
515 |
+
lr_mul=lr_mul,
|
516 |
+
noise_type=noise_type,
|
517 |
+
activation_type='lrelu',
|
518 |
+
use_style=True,
|
519 |
+
eps=eps))
|
520 |
+
tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
|
521 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
522 |
+
f'{res}x{res}/{tf_layer_name}/weight')
|
523 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
524 |
+
f'{res}x{res}/{tf_layer_name}/bias')
|
525 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
526 |
+
f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
|
527 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
528 |
+
f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
|
529 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
|
530 |
+
f'{res}x{res}/{tf_layer_name}/Noise/weight')
|
531 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
|
532 |
+
f'noise{2 * block_idx + 1}')
|
533 |
+
|
534 |
+
# Output convolution layer for each resolution.
|
535 |
+
self.add_module(f'output{block_idx}',
|
536 |
+
ModulateConvLayer(in_channels=out_channels,
|
537 |
+
out_channels=image_channels,
|
538 |
+
resolution=res,
|
539 |
+
w_dim=w_dim,
|
540 |
+
kernel_size=1,
|
541 |
+
add_bias=True,
|
542 |
+
scale_factor=1,
|
543 |
+
fused_scale=False,
|
544 |
+
filter_kernel=None,
|
545 |
+
use_wscale=use_wscale,
|
546 |
+
wscale_gain=1.0,
|
547 |
+
lr_mul=lr_mul,
|
548 |
+
noise_type='none',
|
549 |
+
activation_type='linear',
|
550 |
+
use_style=False,
|
551 |
+
eps=eps))
|
552 |
+
self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
|
553 |
+
f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
|
554 |
+
self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
|
555 |
+
f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')
|
556 |
+
|
557 |
+
def get_nf(self, res):
|
558 |
+
"""Gets number of feature maps according to the given resolution."""
|
559 |
+
return min(self.fmaps_base // res, self.fmaps_max)
|
560 |
+
|
561 |
+
def set_space_of_latent(self, space_of_latent):
|
562 |
+
"""Sets the space to which the latent code belong.
|
563 |
+
|
564 |
+
This function is particularly used for choosing how to inject the latent
|
565 |
+
code into the convolutional layers. The original generator will take a
|
566 |
+
W-Space code and apply it for style modulation after an affine
|
567 |
+
transformation. But, sometimes, it may need to directly feed an already
|
568 |
+
affine-transformed code into the convolutional layer, e.g., when
|
569 |
+
training an encoder for GAN inversion. We term the transformed space as
|
570 |
+
Style Space (or Y-Space). This function is designed to tell the
|
571 |
+
convolutional layers how to use the input code.
|
572 |
+
|
573 |
+
Args:
|
574 |
+
space_of_latent: The space to which the latent code belong. Case
|
575 |
+
insensitive. Support `W` and `Y`.
|
576 |
+
"""
|
577 |
+
space_of_latent = space_of_latent.upper()
|
578 |
+
for module in self.modules():
|
579 |
+
if isinstance(module, ModulateConvLayer) and module.use_style:
|
580 |
+
setattr(module, 'space_of_latent', space_of_latent)
|
581 |
+
|
582 |
+
def forward(self, wp, lod=None, noise_mode='const'):
|
583 |
+
lod = self.lod.item() if lod is None else lod
|
584 |
+
if lod + self.init_res_log2 > self.final_res_log2:
|
585 |
+
raise ValueError(f'Maximum level-of-details (lod) is '
|
586 |
+
f'{self.final_res_log2 - self.init_res_log2}, '
|
587 |
+
f'but `{lod}` is received!')
|
588 |
+
|
589 |
+
results = {'wp': wp}
|
590 |
+
x = None
|
591 |
+
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
|
592 |
+
current_lod = self.final_res_log2 - res_log2
|
593 |
+
block_idx = res_log2 - self.init_res_log2
|
594 |
+
if lod < current_lod + 1:
|
595 |
+
layer = getattr(self, f'layer{2 * block_idx}')
|
596 |
+
x, style = layer(x, wp[:, 2 * block_idx], noise_mode)
|
597 |
+
results[f'style{2 * block_idx}'] = style
|
598 |
+
layer = getattr(self, f'layer{2 * block_idx + 1}')
|
599 |
+
x, style = layer(x, wp[:, 2 * block_idx + 1], noise_mode)
|
600 |
+
results[f'style{2 * block_idx + 1}'] = style
|
601 |
+
if current_lod - 1 < lod <= current_lod:
|
602 |
+
image = getattr(self, f'output{block_idx}')(x)
|
603 |
+
elif current_lod < lod < current_lod + 1:
|
604 |
+
alpha = np.ceil(lod) - lod
|
605 |
+
temp = getattr(self, f'output{block_idx}')(x)
|
606 |
+
image = F.interpolate(image, scale_factor=2, mode='nearest')
|
607 |
+
image = temp * alpha + image * (1 - alpha)
|
608 |
+
elif lod >= current_lod + 1:
|
609 |
+
image = F.interpolate(image, scale_factor=2, mode='nearest')
|
610 |
+
|
611 |
+
if self.final_tanh:
|
612 |
+
image = torch.tanh(image)
|
613 |
+
results['image'] = image
|
614 |
+
return results
|
615 |
+
|
616 |
+
|
617 |
+
class PixelNormLayer(nn.Module):
|
618 |
+
"""Implements pixel-wise feature vector normalization layer."""
|
619 |
+
|
620 |
+
def __init__(self, dim, eps):
|
621 |
+
super().__init__()
|
622 |
+
self.dim = dim
|
623 |
+
self.eps = eps
|
624 |
+
|
625 |
+
def extra_repr(self):
|
626 |
+
return f'dim={self.dim}, epsilon={self.eps}'
|
627 |
+
|
628 |
+
def forward(self, x):
|
629 |
+
scale = (x.square().mean(dim=self.dim, keepdim=True) + self.eps).rsqrt()
|
630 |
+
return x * scale
|
631 |
+
|
632 |
+
|
633 |
+
class Blur(torch.autograd.Function):
|
634 |
+
"""Defines blur operation with customized gradient computation."""
|
635 |
+
|
636 |
+
@staticmethod
|
637 |
+
def forward(ctx, x, kernel):
|
638 |
+
assert kernel.shape[2] == 3 and kernel.shape[3] == 3
|
639 |
+
ctx.save_for_backward(kernel)
|
640 |
+
y = F.conv2d(input=x,
|
641 |
+
weight=kernel,
|
642 |
+
bias=None,
|
643 |
+
stride=1,
|
644 |
+
padding=1,
|
645 |
+
groups=x.shape[1])
|
646 |
+
return y
|
647 |
+
|
648 |
+
@staticmethod
|
649 |
+
def backward(ctx, dy):
|
650 |
+
kernel, = ctx.saved_tensors
|
651 |
+
dx = F.conv2d(input=dy,
|
652 |
+
weight=kernel.flip((2, 3)),
|
653 |
+
bias=None,
|
654 |
+
stride=1,
|
655 |
+
padding=1,
|
656 |
+
groups=dy.shape[1])
|
657 |
+
return dx, None, None
|
658 |
+
|
659 |
+
|
660 |
+
class ModulateConvLayer(nn.Module):
|
661 |
+
"""Implements the convolutional layer with style modulation."""
|
662 |
+
|
663 |
+
def __init__(self,
|
664 |
+
in_channels,
|
665 |
+
out_channels,
|
666 |
+
resolution,
|
667 |
+
w_dim,
|
668 |
+
kernel_size,
|
669 |
+
add_bias,
|
670 |
+
scale_factor,
|
671 |
+
fused_scale,
|
672 |
+
filter_kernel,
|
673 |
+
use_wscale,
|
674 |
+
wscale_gain,
|
675 |
+
lr_mul,
|
676 |
+
noise_type,
|
677 |
+
activation_type,
|
678 |
+
use_style,
|
679 |
+
eps):
|
680 |
+
"""Initializes with layer settings.
|
681 |
+
|
682 |
+
Args:
|
683 |
+
in_channels: Number of channels of the input tensor.
|
684 |
+
out_channels: Number of channels of the output tensor.
|
685 |
+
resolution: Resolution of the output tensor.
|
686 |
+
w_dim: Dimension of W space for style modulation.
|
687 |
+
kernel_size: Size of the convolutional kernels.
|
688 |
+
add_bias: Whether to add bias onto the convolutional result.
|
689 |
+
scale_factor: Scale factor for upsampling.
|
690 |
+
fused_scale: Whether to fuse `upsample` and `conv2d` as one
|
691 |
+
operator, using transpose convolution.
|
692 |
+
filter_kernel: Kernel used for filtering.
|
693 |
+
use_wscale: Whether to use weight scaling.
|
694 |
+
wscale_gain: Gain factor for weight scaling.
|
695 |
+
lr_mul: Learning multiplier for both weight and bias.
|
696 |
+
noise_type: Type of noise added to the feature map after the
|
697 |
+
convolution (if needed). Support `none`, `spatial` and
|
698 |
+
`channel`.
|
699 |
+
activation_type: Type of activation.
|
700 |
+
use_style: Whether to apply style modulation.
|
701 |
+
eps: A small value to avoid divide overflow.
|
702 |
+
"""
|
703 |
+
super().__init__()
|
704 |
+
|
705 |
+
self.in_channels = in_channels
|
706 |
+
self.out_channels = out_channels
|
707 |
+
self.resolution = resolution
|
708 |
+
self.w_dim = w_dim
|
709 |
+
self.kernel_size = kernel_size
|
710 |
+
self.add_bias = add_bias
|
711 |
+
self.scale_factor = scale_factor
|
712 |
+
self.fused_scale = fused_scale
|
713 |
+
self.filter_kernel = filter_kernel
|
714 |
+
self.use_wscale = use_wscale
|
715 |
+
self.wscale_gain = wscale_gain
|
716 |
+
self.lr_mul = lr_mul
|
717 |
+
self.noise_type = noise_type.lower()
|
718 |
+
self.activation_type = activation_type
|
719 |
+
self.use_style = use_style
|
720 |
+
self.eps = eps
|
721 |
+
|
722 |
+
# Set up noise.
|
723 |
+
if self.noise_type == 'none':
|
724 |
+
pass
|
725 |
+
elif self.noise_type == 'spatial':
|
726 |
+
self.register_buffer(
|
727 |
+
'noise', torch.randn(1, 1, resolution, resolution))
|
728 |
+
self.noise_strength = nn.Parameter(
|
729 |
+
torch.zeros(1, out_channels, 1, 1))
|
730 |
+
elif self.noise_type == 'channel':
|
731 |
+
self.register_buffer(
|
732 |
+
'noise', torch.randn(1, out_channels, 1, 1))
|
733 |
+
self.noise_strength = nn.Parameter(
|
734 |
+
torch.zeros(1, 1, resolution, resolution))
|
735 |
+
else:
|
736 |
+
raise NotImplementedError(f'Not implemented noise type: '
|
737 |
+
f'`{noise_type}`!')
|
738 |
+
|
739 |
+
# Set up bias.
|
740 |
+
if add_bias:
|
741 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
742 |
+
self.bscale = lr_mul
|
743 |
+
else:
|
744 |
+
self.bias = None
|
745 |
+
|
746 |
+
# Set up activation.
|
747 |
+
assert activation_type in ['linear', 'relu', 'lrelu']
|
748 |
+
|
749 |
+
# Set up style.
|
750 |
+
if use_style:
|
751 |
+
self.space_of_latent = 'W'
|
752 |
+
self.style = DenseLayer(in_channels=w_dim,
|
753 |
+
out_channels=out_channels * 2,
|
754 |
+
add_bias=True,
|
755 |
+
use_wscale=use_wscale,
|
756 |
+
wscale_gain=1.0,
|
757 |
+
lr_mul=1.0,
|
758 |
+
activation_type='linear')
|
759 |
+
|
760 |
+
if in_channels == 0: # First layer.
|
761 |
+
self.const = nn.Parameter(
|
762 |
+
torch.ones(1, out_channels, resolution, resolution))
|
763 |
+
return
|
764 |
+
|
765 |
+
# Set up weight.
|
766 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
767 |
+
fan_in = kernel_size * kernel_size * in_channels
|
768 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
769 |
+
if use_wscale:
|
770 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
771 |
+
self.wscale = wscale * lr_mul
|
772 |
+
else:
|
773 |
+
self.weight = nn.Parameter(
|
774 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
775 |
+
self.wscale = lr_mul
|
776 |
+
|
777 |
+
# Set up upsampling filter (if needed).
|
778 |
+
if scale_factor > 1:
|
779 |
+
assert filter_kernel is not None
|
780 |
+
kernel = np.array(filter_kernel, dtype=np.float32).reshape(1, -1)
|
781 |
+
kernel = kernel.T.dot(kernel)
|
782 |
+
kernel = kernel / np.sum(kernel)
|
783 |
+
kernel = kernel[np.newaxis, np.newaxis]
|
784 |
+
self.register_buffer('filter', torch.from_numpy(kernel))
|
785 |
+
|
786 |
+
if scale_factor > 1 and fused_scale: # use transpose convolution.
|
787 |
+
self.stride = scale_factor
|
788 |
+
else:
|
789 |
+
self.stride = 1
|
790 |
+
self.padding = kernel_size // 2
|
791 |
+
|
792 |
+
def extra_repr(self):
|
793 |
+
return (f'in_ch={self.in_channels}, '
|
794 |
+
f'out_ch={self.out_channels}, '
|
795 |
+
f'ksize={self.kernel_size}, '
|
796 |
+
f'wscale_gain={self.wscale_gain:.3f}, '
|
797 |
+
f'bias={self.add_bias}, '
|
798 |
+
f'lr_mul={self.lr_mul:.3f}, '
|
799 |
+
f'upsample={self.scale_factor}, '
|
800 |
+
f'fused_scale={self.fused_scale}, '
|
801 |
+
f'upsample_filter={self.filter_kernel}, '
|
802 |
+
f'noise_type={self.noise_type}, '
|
803 |
+
f'act={self.activation_type}, '
|
804 |
+
f'use_style={self.use_style}')
|
805 |
+
|
806 |
+
def forward_style(self, w):
|
807 |
+
"""Gets style code from the given input.
|
808 |
+
|
809 |
+
More specifically, if the input is from W-Space, it will be projected by
|
810 |
+
an affine transformation. If it is from the Style Space (Y-Space), no
|
811 |
+
operation is required.
|
812 |
+
|
813 |
+
NOTE: For codes from Y-Space, we use slicing to make sure the dimension
|
814 |
+
is correct, in case that the code is padded before fed into this layer.
|
815 |
+
"""
|
816 |
+
space_of_latent = self.space_of_latent.upper()
|
817 |
+
if space_of_latent == 'W':
|
818 |
+
if w.ndim != 2 or w.shape[1] != self.w_dim:
|
819 |
+
raise ValueError(f'The input tensor should be with shape '
|
820 |
+
f'[batch_size, w_dim], where '
|
821 |
+
f'`w_dim` equals to {self.w_dim}!\n'
|
822 |
+
f'But `{w.shape}` is received!')
|
823 |
+
style = self.style(w)
|
824 |
+
elif space_of_latent == 'Y':
|
825 |
+
if w.ndim != 2 or w.shape[1] < self.out_channels * 2:
|
826 |
+
raise ValueError(f'The input tensor should be with shape '
|
827 |
+
f'[batch_size, y_dim], where '
|
828 |
+
f'`y_dim` equals to {self.out_channels * 2}!\n'
|
829 |
+
f'But `{w.shape}` is received!')
|
830 |
+
style = w[:, :self.out_channels * 2]
|
831 |
+
else:
|
832 |
+
raise NotImplementedError(f'Not implemented `space_of_latent`: '
|
833 |
+
f'`{space_of_latent}`!')
|
834 |
+
return style
|
835 |
+
|
836 |
+
def forward(self, x, w=None, noise_mode='const'):
|
837 |
+
if self.in_channels == 0:
|
838 |
+
assert x is None
|
839 |
+
x = self.const.repeat(w.shape[0], 1, 1, 1)
|
840 |
+
else:
|
841 |
+
weight = self.weight
|
842 |
+
if self.wscale != 1.0:
|
843 |
+
weight = weight * self.wscale
|
844 |
+
|
845 |
+
if self.scale_factor > 1 and self.fused_scale:
|
846 |
+
weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0)
|
847 |
+
weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
|
848 |
+
weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
|
849 |
+
x = F.conv_transpose2d(x,
|
850 |
+
weight=weight.transpose(0, 1),
|
851 |
+
bias=None,
|
852 |
+
stride=self.stride,
|
853 |
+
padding=self.padding)
|
854 |
+
else:
|
855 |
+
if self.scale_factor > 1:
|
856 |
+
up = self.scale_factor
|
857 |
+
x = F.interpolate(x, scale_factor=up, mode='nearest')
|
858 |
+
x = F.conv2d(x,
|
859 |
+
weight=weight,
|
860 |
+
bias=None,
|
861 |
+
stride=self.stride,
|
862 |
+
padding=self.padding)
|
863 |
+
|
864 |
+
if self.scale_factor > 1:
|
865 |
+
# Disable `autocast` for customized autograd function.
|
866 |
+
# Please check reference:
|
867 |
+
# https://pytorch.org/docs/stable/notes/amp_examples.html#autocast-and-custom-autograd-functions
|
868 |
+
with autocast(enabled=False):
|
869 |
+
f = self.filter.repeat(self.out_channels, 1, 1, 1)
|
870 |
+
x = Blur.apply(x.float(), f) # Always use FP32.
|
871 |
+
|
872 |
+
# Prepare noise.
|
873 |
+
noise_mode = noise_mode.lower()
|
874 |
+
if self.noise_type != 'none' and noise_mode != 'none':
|
875 |
+
if noise_mode == 'random':
|
876 |
+
noise = torch.randn(
|
877 |
+
(x.shape[0], *self.noise.shape[1:]), device=x.device)
|
878 |
+
elif noise_mode == 'const':
|
879 |
+
noise = self.noise
|
880 |
+
else:
|
881 |
+
raise ValueError(f'Unknown noise mode `{noise_mode}`!')
|
882 |
+
x = x + noise * self.noise_strength
|
883 |
+
|
884 |
+
if self.bias is not None:
|
885 |
+
bias = self.bias
|
886 |
+
if self.bscale != 1.0:
|
887 |
+
bias = bias * self.bscale
|
888 |
+
x = x + bias.reshape(1, self.out_channels, 1, 1)
|
889 |
+
|
890 |
+
if self.activation_type == 'linear':
|
891 |
+
pass
|
892 |
+
elif self.activation_type == 'relu':
|
893 |
+
x = F.relu(x, inplace=True)
|
894 |
+
elif self.activation_type == 'lrelu':
|
895 |
+
x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
|
896 |
+
else:
|
897 |
+
raise NotImplementedError(f'Not implemented activation type '
|
898 |
+
f'`{self.activation_type}`!')
|
899 |
+
|
900 |
+
if not self.use_style:
|
901 |
+
return x
|
902 |
+
|
903 |
+
# Instance normalization.
|
904 |
+
x = x - x.mean(dim=(2, 3), keepdim=True)
|
905 |
+
scale = (x.square().mean(dim=(2, 3), keepdim=True) + self.eps).rsqrt()
|
906 |
+
x = x * scale
|
907 |
+
# Style modulation.
|
908 |
+
style = self.forward_style(w)
|
909 |
+
style_split = style.unsqueeze(2).unsqueeze(3).chunk(2, dim=1)
|
910 |
+
x = x * (style_split[0] + 1) + style_split[1]
|
911 |
+
|
912 |
+
return x, style
|
913 |
+
|
914 |
+
|
915 |
+
class DenseLayer(nn.Module):
|
916 |
+
"""Implements the dense layer."""
|
917 |
+
|
918 |
+
def __init__(self,
|
919 |
+
in_channels,
|
920 |
+
out_channels,
|
921 |
+
add_bias,
|
922 |
+
use_wscale,
|
923 |
+
wscale_gain,
|
924 |
+
lr_mul,
|
925 |
+
activation_type):
|
926 |
+
"""Initializes with layer settings.
|
927 |
+
|
928 |
+
Args:
|
929 |
+
in_channels: Number of channels of the input tensor.
|
930 |
+
out_channels: Number of channels of the output tensor.
|
931 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
932 |
+
use_wscale: Whether to use weight scaling.
|
933 |
+
wscale_gain: Gain factor for weight scaling.
|
934 |
+
lr_mul: Learning multiplier for both weight and bias.
|
935 |
+
activation_type: Type of activation.
|
936 |
+
"""
|
937 |
+
super().__init__()
|
938 |
+
self.in_channels = in_channels
|
939 |
+
self.out_channels = out_channels
|
940 |
+
self.add_bias = add_bias
|
941 |
+
self.use_wscale = use_wscale
|
942 |
+
self.wscale_gain = wscale_gain
|
943 |
+
self.lr_mul = lr_mul
|
944 |
+
self.activation_type = activation_type
|
945 |
+
|
946 |
+
weight_shape = (out_channels, in_channels)
|
947 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
948 |
+
if use_wscale:
|
949 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
950 |
+
self.wscale = wscale * lr_mul
|
951 |
+
else:
|
952 |
+
self.weight = nn.Parameter(
|
953 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
954 |
+
self.wscale = lr_mul
|
955 |
+
|
956 |
+
if add_bias:
|
957 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
958 |
+
self.bscale = lr_mul
|
959 |
+
else:
|
960 |
+
self.bias = None
|
961 |
+
|
962 |
+
assert activation_type in ['linear', 'relu', 'lrelu']
|
963 |
+
|
964 |
+
def extra_repr(self):
|
965 |
+
return (f'in_ch={self.in_channels}, '
|
966 |
+
f'out_ch={self.out_channels}, '
|
967 |
+
f'wscale_gain={self.wscale_gain:.3f}, '
|
968 |
+
f'bias={self.add_bias}, '
|
969 |
+
f'lr_mul={self.lr_mul:.3f}, '
|
970 |
+
f'act={self.activation_type}')
|
971 |
+
|
972 |
+
def forward(self, x):
|
973 |
+
if x.ndim != 2:
|
974 |
+
x = x.flatten(start_dim=1)
|
975 |
+
|
976 |
+
weight = self.weight
|
977 |
+
if self.wscale != 1.0:
|
978 |
+
weight = weight * self.wscale
|
979 |
+
bias = None
|
980 |
+
if self.bias is not None:
|
981 |
+
bias = self.bias
|
982 |
+
if self.bscale != 1.0:
|
983 |
+
bias = bias * self.bscale
|
984 |
+
|
985 |
+
x = F.linear(x, weight=weight, bias=bias)
|
986 |
+
|
987 |
+
if self.activation_type == 'linear':
|
988 |
+
pass
|
989 |
+
elif self.activation_type == 'relu':
|
990 |
+
x = F.relu(x, inplace=True)
|
991 |
+
elif self.activation_type == 'lrelu':
|
992 |
+
x = F.leaky_relu(x, negative_slope=0.2, inplace=True)
|
993 |
+
else:
|
994 |
+
raise NotImplementedError(f'Not implemented activation type '
|
995 |
+
f'`{self.activation_type}`!')
|
996 |
+
|
997 |
+
return x
|
998 |
+
|
999 |
+
# pylint: enable=missing-function-docstring
|
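For orientation, here is a minimal usage sketch of the generator defined above. The `build_model` call and the registry name `'StyleGANGenerator'` are assumptions based on this repo's `models/__init__.py`; the truncation arguments mirror the `forward()` signature shown in the diff, and a CUDA device is assumed.

    # Minimal sketch, assuming the registry exposes the class above as
    # 'StyleGANGenerator' and a GPU is available.
    import torch
    from models import build_model

    G = build_model('StyleGANGenerator', resolution=256).eval().cuda()
    z = torch.randn(4, *G.latent_dim).cuda()           # [batch, z_dim]
    with torch.no_grad():
        outputs = G(z, trunc_psi=0.7, trunc_layers=8)  # truncation trick
    image = outputs['image']   # [4, image_channels, 256, 256]
    wp = outputs['wp']         # [4, num_layers, w_dim], per-layer codes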
models/test.py
ADDED
@@ -0,0 +1,146 @@
+# python3.7
+"""Unit test for loading pre-trained models.
+
+Basically, this file tests whether the perceptual model (VGG16) and the
+inception model (InceptionV3), which are commonly used for loss computation and
+evaluation, have the expected behavior after loading pre-trained weights. In
+particular, we compare with the models from repo
+
+https://github.com/NVlabs/stylegan2-ada-pytorch
+"""
+
+import torch
+
+from models import build_model
+from utils.misc import download_url
+
+__all__ = ['test_model']
+
+_BATCH_SIZE = 4
+# pylint: disable=line-too-long
+_PERCEPTUAL_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+_INCEPTION_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
+# pylint: enable=line-too-long
+
+
+def test_model():
+    """Collects all model tests."""
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.allow_tf32 = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+    print('========== Start Model Test ==========')
+    test_perceptual()
+    test_inception()
+    print('========== Finish Model Test ==========')
+
+
+def test_perceptual():
+    """Test the perceptual model."""
+    print('===== Testing Perceptual Model =====')
+
+    print('Build test model.')
+    model = build_model('PerceptualModel',
+                        use_torchvision=False,
+                        no_top=False,
+                        enable_lpips=True)
+
+    print('Build reference model.')
+    ref_model_path, _, = download_url(_PERCEPTUAL_URL)
+    with open(ref_model_path, 'rb') as f:
+        ref_model = torch.jit.load(f).eval().cuda()
+
+    print('Test performance: ')
+    for size in [224, 128, 256, 512, 1024]:
+        raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size))
+        raw_img_comp = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size))
+
+        # The test model requires input images to have range [-1, 1].
+        img = raw_img.to(torch.float32).cuda() / 127.5 - 1
+        img_comp = raw_img_comp.to(torch.float32).cuda() / 127.5 - 1
+        feat = model(img, resize_input=True, return_tensor='feature')
+        pred = model(img, resize_input=True, return_tensor='prediction')
+        lpips = model(img, img_comp, resize_input=False, return_tensor='lpips')
+        assert feat.shape == (_BATCH_SIZE, 4096)
+        assert pred.shape == (_BATCH_SIZE, 1000)
+        assert lpips.shape == (_BATCH_SIZE,)
+
+        # The reference model requires input images to have range [0, 255].
+        img = raw_img.to(torch.float32).cuda()
+        img_comp = raw_img_comp.to(torch.float32).cuda()
+        ref_feat = ref_model(img, resize_images=True, return_features=True)
+        ref_pred = ref_model(img, resize_images=True, return_features=False)
+        temp = ref_model(torch.cat([img, img_comp], dim=0),
+                         resize_images=False, return_lpips=True).chunk(2)
+        ref_lpips = (temp[0] - temp[1]).square().sum(dim=1, keepdim=False)
+        assert ref_feat.shape == (_BATCH_SIZE, 4096)
+        assert ref_pred.shape == (_BATCH_SIZE, 1000)
+        assert ref_lpips.shape == (_BATCH_SIZE,)
+
+        print(f'  Size {size}x{size}, feature (with resize):\n  '
+              f'mean: {(feat - ref_feat).abs().mean().item():.3e}, '
+              f'max: {(feat - ref_feat).abs().max().item():.3e}, '
+              f'ref_mean: {ref_feat.abs().mean().item():.3e}, '
+              f'ref_max: {ref_feat.abs().max().item():.3e}.')
+        print(f'  Size {size}x{size}, prediction (with resize):\n  '
+              f'mean: {(pred - ref_pred).abs().mean().item():.3e}, '
+              f'max: {(pred - ref_pred).abs().max().item():.3e}, '
+              f'ref_mean: {ref_pred.abs().mean().item():.3e}, '
+              f'ref_max: {ref_pred.abs().max().item():.3e}.')
+        print(f'  Size {size}x{size}, LPIPS (without resize):\n  '
+              f'mean: {(lpips - ref_lpips).abs().mean().item():.3e}, '
+              f'max: {(lpips - ref_lpips).abs().max().item():.3e}, '
+              f'ref_mean: {ref_lpips.abs().mean().item():.3e}, '
+              f'ref_max: {ref_lpips.abs().max().item():.3e}.')
+
+
+def test_inception():
+    """Test the inception model."""
+    print('===== Testing Inception Model =====')
+
+    print('Build test model.')
+    model = build_model('InceptionModel', align_tf=True)
+
+    print('Build reference model.')
+    ref_model_path, _, = download_url(_INCEPTION_URL)
+    with open(ref_model_path, 'rb') as f:
+        ref_model = torch.jit.load(f).eval().cuda()
+
+    print('Test performance: ')
+    for size in [299, 128, 256, 512, 1024]:
+        raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size))
+
+        # The test model requires input images to have range [-1, 1].
+        img = raw_img.to(torch.float32).cuda() / 127.5 - 1
+        feat = model(img)
+        pred = model(img, output_predictions=True)
+        pred_nb = model(img, output_predictions=True, remove_logits_bias=True)
+        assert feat.shape == (_BATCH_SIZE, 2048)
+        assert pred.shape == (_BATCH_SIZE, 1008)
+        assert pred_nb.shape == (_BATCH_SIZE, 1008)
+
+        # The reference model requires input images to have range [0, 255].
+        img = raw_img.to(torch.float32).cuda()
+        ref_feat = ref_model(img, return_features=True)
+        ref_pred = ref_model(img)
+        ref_pred_nb = ref_model(img, no_output_bias=True)
+        assert ref_feat.shape == (_BATCH_SIZE, 2048)
+        assert ref_pred.shape == (_BATCH_SIZE, 1008)
+        assert ref_pred_nb.shape == (_BATCH_SIZE, 1008)
+
+        print(f'  Size {size}x{size}, feature:\n  '
+              f'mean: {(feat - ref_feat).abs().mean().item():.3e}, '
+              f'max: {(feat - ref_feat).abs().max().item():.3e}, '
+              f'ref_mean: {ref_feat.abs().mean().item():.3e}, '
+              f'ref_max: {ref_feat.abs().max().item():.3e}.')
+        print(f'  Size {size}x{size}, prediction:\n  '
+              f'mean: {(pred - ref_pred).abs().mean().item():.3e}, '
+              f'max: {(pred - ref_pred).abs().max().item():.3e}, '
+              f'ref_mean: {ref_pred.abs().mean().item():.3e}, '
+              f'ref_max: {ref_pred.abs().max().item():.3e}.')
+        print(f'  Size {size}x{size}, prediction (without bias):\n  '
+              f'mean: {(pred_nb - ref_pred_nb).abs().mean().item():.3e}, '
+              f'max: {(pred_nb - ref_pred_nb).abs().max().item():.3e}, '
+              f'ref_mean: {ref_pred_nb.abs().mean().item():.3e}, '
+              f'ref_max: {ref_pred_nb.abs().max().item():.3e}.')
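Running the checks above is a one-liner; note that it downloads the two NVIDIA reference TorchScript models and needs a CUDA device:

    # Assumes a GPU and network access for the reference checkpoints.
    from models.test import test_model

    test_model()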
models/utils/__init__.py
ADDED
File without changes
models/utils/ops.py
ADDED
@@ -0,0 +1,18 @@
+# python3.7
+"""Contains operators for neural networks."""
+
+import torch
+import torch.distributed as dist
+
+__all__ = ['all_gather']
+
+
+def all_gather(tensor):
+    """Gathers tensor from all devices and executes averaging."""
+    if not dist.is_initialized():
+        return tensor
+
+    world_size = dist.get_world_size()
+    tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
+    dist.all_gather(tensor_list, tensor, async_op=False)
+    return torch.stack(tensor_list, dim=0).mean(dim=0)
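One property worth noting: `all_gather` falls back to the identity when `torch.distributed` has not been initialized, so callers such as `StyleGANGenerator.forward()` above can use it unconditionally. A single-process sketch:

    import torch
    from models.utils.ops import all_gather

    w = torch.randn(8, 512)
    # Without an initialized process group this is just w.detach().mean(dim=0);
    # with one, it averages the gathered tensors first, matching the w_avg update.
    batch_w_avg = all_gather(w.detach()).mean(dim=0)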
requirements/convert.txt
ADDED
@@ -0,0 +1,11 @@
+torch==1.8.1
+tensorflow-gpu==1.15
+ninja==1.10.2
+scikit-video==1.1.11
+pillow==9.0.0
+opencv-python-headless==4.5.5.62
+requests
+bs4
+tqdm
+rich
+easydict
requirements/develop.txt
ADDED
@@ -0,0 +1,3 @@
+bpytop   # Monitor system resources.
+gpustat  # Monitor GPU usage.
+pylint   # Check coding style.
requirements/minimal.txt
ADDED
@@ -0,0 +1,21 @@
+torch==1.8.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
+torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
+tensorboard==2.7.0
+torch-tb-profiler==0.3.1
+ninja==1.10.2
+numpy==1.21.5
+scipy==1.7.3
+scikit-learn==1.0.2
+scikit-video==1.1.11
+pillow==9.0.0
+opencv-python-headless==4.5.5.62
+requests
+bs4
+tqdm
+rich
+click
+cloup
+psutil
+easydict
+lmdb
+matplotlib
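The three requirement files target different workflows: weight conversion from TensorFlow, development tooling, and the minimal runtime set. A sketch of installing the runtime set from Python, to keep this document single-language; plain `pip install -r requirements/minimal.txt` is equivalent:

    import subprocess
    import sys

    # Install the pinned runtime dependencies listed above.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install',
                           '-r', 'requirements/minimal.txt'])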
synthesis.py
ADDED
@@ -0,0 +1,178 @@
+# python3.7
+"""Script that synthesizes images with pre-trained models.
+
+Support StyleGAN2 and StyleGAN3.
+"""
+
+import os
+import argparse
+from tqdm import tqdm
+import numpy as np
+
+import torch
+from models import build_model
+from utils.visualizers.html_visualizer import HtmlVisualizer
+from utils.image_utils import save_image, resize_image
+from utils.image_utils import postprocess_image
+from utils.custom_utils import to_numpy
+
+
+def parse_args():
+    """Parses arguments."""
+    parser = argparse.ArgumentParser()
+    group = parser.add_argument_group('General options.')
+    group.add_argument('weight_path', type=str,
+                       help='Weight path to the pre-trained model.')
+    group.add_argument('--save_dir', type=str, default=None,
+                       help='Directory to save the results. If not specified, '
+                            'the results will be saved to '
+                            '`work_dirs/{TASK_SPECIFIC}/` by default.')
+    group.add_argument('--job', type=str, default='synthesize',
+                       help='Name for the job. (default: synthesize)')
+    group.add_argument('--seed', type=int, default=4,
+                       help='Seed for sampling. (default: 4)')
+    group.add_argument('--nums', type=int, default=100,
+                       help='Number of samples to synthesize. (default: 100)')
+    group.add_argument('--img_size', type=int, default=1024,
+                       help='Size of the synthesized images. (default: 1024)')
+    group.add_argument('--vis_size', type=int, default=256,
+                       help='Size of the visualized images. (default: 256)')
+    group.add_argument('--w_dim', type=int, default=512,
+                       help='Dimension of the latent w. (default: 512)')
+    group.add_argument('--batch_size', type=int, default=4,
+                       help='Batch size. (default: 4)')
+    group.add_argument('--save_jpg', action='store_true', default=False,
+                       help='Whether to save raw images. (default: False)')
+    group.add_argument('-d', '--data_name', type=str, default='ffhq',
+                       help='Name of the dataset. (default: ffhq)')
+    group.add_argument('--latent_path', type=str, default='',
+                       help='Path to the given latent codes. (default: None)')
+    group.add_argument('--trunc_psi', type=float, default=0.7,
+                       help='Psi factor used for truncation. (default: 0.7)')
+    group.add_argument('--trunc_layers', type=int, default=8,
+                       help='Number of layers to perform truncation.'
+                            ' (default: 8)')
+
+    group = parser.add_argument_group('StyleGAN2')
+    group.add_argument('--stylegan2', action='store_true',
+                       help='Whether or not to use StyleGAN2. (default: False)')
+    group.add_argument('--scale_stylegan2', type=float, default=1.0,
+                       help='Scale for the number of channels for stylegan2.')
+    group.add_argument('--randomize_noise', type=str, default='const',
+                       help='Noise type when synthesizing. (const or random)')
+
+    group = parser.add_argument_group('StyleGAN3')
+    group.add_argument('--stylegan3', action='store_true',
+                       help='Whether or not to use StyleGAN3. (default: False)')
+    group.add_argument('--cfg', type=str, default='T',
+                       help='Config of the stylegan3 (T/R).')
+    group.add_argument('--scale_stylegan3r', type=float, default=2.0,
+                       help='Scale for the number of channels for stylegan3 R.')
+    group.add_argument('--scale_stylegan3t', type=float, default=1.0,
+                       help='Scale for the number of channels for stylegan3 T.')
+    group.add_argument('--tx', type=float, default=0,
+                       help='Translate X-coordinate. (default: 0.0)')
+    group.add_argument('--ty', type=float, default=0,
+                       help='Translate Y-coordinate. (default: 0.0)')
+    group.add_argument('--rotate', type=float, default=0,
+                       help='Rotation angle in degrees. (default: 0)')
+    return parser.parse_args()
+
+
+def main():
+    """Main function."""
+    args = parse_args()
+    # Parse model configuration.
+    assert (args.stylegan2 and not args.stylegan3) or \
+           (not args.stylegan2 and args.stylegan3)
+    job_disc = ''
+    if args.stylegan2:
+        config = dict(model_type='StyleGAN2Generator',
+                      resolution=args.img_size,
+                      w_dim=args.w_dim,
+                      fmaps_base=int(args.scale_stylegan2 * (32 << 10)),
+                      fmaps_max=512)
+        job_disc += 'stylegan2'
+    else:
+        if args.stylegan3 and args.cfg == 'R':
+            config = dict(model_type='StyleGAN3Generator',
+                          resolution=args.img_size,
+                          w_dim=args.w_dim,
+                          fmaps_base=int(args.scale_stylegan3r * (32 << 10)),
+                          fmaps_max=1024,
+                          use_radial_filter=True)
+            job_disc += 'stylegan3r'
+        elif args.stylegan3 and args.cfg == 'T':
+            config = dict(model_type='StyleGAN3Generator',
+                          resolution=args.img_size,
+                          w_dim=args.w_dim,
+                          fmaps_base=int(args.scale_stylegan3t * (32 << 10)),
+                          fmaps_max=512,
+                          use_radial_filter=False,
+                          kernel_size=3)
+            job_disc += 'stylegan3t'
+        else:
+            raise TypeError(f'StyleGAN3 config type error, need `R/T`,'
+                            f' but got {args.cfg} instead.')
+
+    # Get work directory and job name.
+    save_dir = args.save_dir or f'work_dirs/{args.job}/{args.data_name}'
+    os.makedirs(save_dir, exist_ok=True)
+    job_name = f'seed_{args.seed}_num_{args.nums}_{job_disc}'
+    os.makedirs(f'{save_dir}/{job_name}', exist_ok=True)
+
+    # Build the generator and get synthesis kwargs.
+    print('Building generator...')
+    generator = build_model(**config)
+    synthesis_kwargs = dict(trunc_psi=args.trunc_psi,
+                            trunc_layers=args.trunc_layers)
+    # Load pre-trained weights.
+    checkpoint_path = args.weight_path
+    print(f'Loading checkpoint from `{checkpoint_path}` ...')
+    checkpoint = torch.load(checkpoint_path, map_location='cpu')['models']
+    if 'generator_smooth' in checkpoint:
+        generator.load_state_dict(checkpoint['generator_smooth'])
+    else:
+        generator.load_state_dict(checkpoint['generator'])
+    generator = generator.eval().cuda()
+    print('Finish loading checkpoint.')
+
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if os.path.exists(args.latent_path):
+        latent_zs = np.load(args.latent_path)
+        latent_zs = latent_zs[:args.nums]
+    else:
+        latent_zs = np.random.randn(args.nums, generator.z_dim)
+    num_images = latent_zs.shape[0]
+    latent_zs = torch.from_numpy(latent_zs.astype(np.float32))
+    html = HtmlVisualizer(grid_size=num_images)
+    print(f'Synthesizing {num_images} images ...')
+    latent_ws = []
+    for batch_idx in tqdm(range(0, num_images, args.batch_size)):
+        latent_z = latent_zs[batch_idx:batch_idx + args.batch_size]
+        latent_z = latent_z.cuda()
+        with torch.no_grad():
+            g_outputs = generator(latent_z, **synthesis_kwargs)
+        g_image = to_numpy(g_outputs['image'])
+        images = postprocess_image(g_image)
+        for idx in range(images.shape[0]):
+            sub_idx = batch_idx + idx
+            img = images[idx]
+            row_idx, col_idx = divmod(sub_idx, html.num_cols)
+            image = resize_image(img, (args.vis_size, args.vis_size))
+            html.set_cell(row_idx, col_idx, image=image,
+                          text=f'Sample {sub_idx:06d}')
+            if args.save_jpg:
+                save_path = f'{save_dir}/{job_name}/{sub_idx:06d}.jpg'
+                save_image(save_path, img)
+        latent_ws.append(to_numpy(g_outputs['wp']))
+    latent_ws = np.concatenate(latent_ws, axis=0)
+    print(f'shape of the latent code: {latent_ws.shape}')
+    np.save(f'{save_dir}/{job_name}/latent_codes.npy', latent_ws)
+    html.save(f'{save_dir}/{job_name}.html')
+    print(f'Finish synthesizing {num_images} samples.')
+
+
+if __name__ == '__main__':
+    main()
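For reference, with the defaults above (`--job synthesize`, `--data_name ffhq`, `--seed 4`, `--nums 100`) and the `--stylegan2` flag, the script writes the visualization grid to `work_dirs/synthesize/ffhq/seed_4_num_100_stylegan2.html` and the sampled W+ codes alongside it. A sketch of reloading them (the shape is an example for 1024x1024 StyleGAN2):

    import numpy as np

    # Path follows save_dir/job_name as constructed in main() above.
    latent_ws = np.load(
        'work_dirs/synthesize/ffhq/seed_4_num_100_stylegan2/latent_codes.npy')
    print(latent_ws.shape)  # (nums, num_layers, w_dim), e.g. (100, 18, 512)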
third_party/__init__.py
ADDED
File without changes
third_party/stylegan2_official_ops/README.md
ADDED
@@ -0,0 +1,28 @@
+# Operators for StyleGAN2
+
+All files in this directory are borrowed from the repository [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including
+
+- `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator.
+- `upfirdn2d.setup_filter()`: Set up the kernel used for filtering.
+- `upfirdn2d.filter2d()`: Filter a 2D feature map with a given kernel.
+- `upfirdn2d.upsample2d()`: Upsample a 2D feature map.
+- `upfirdn2d.downsample2d()`: Downsample a 2D feature map.
+- `upfirdn2d.upfirdn2d()`: Upsample, filter, and then downsample a 2D feature map.
+- `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing the gradient when computing penalty.
+- `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing the gradient when computing penalty.
+- `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`).
+
+We make the following slight modifications beyond disabling some lint warnings:
+
+- Line 25 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
+- Line 35 of file `custom_ops.py`: Disable log message when setting up customized operators.
+- Line 53/89 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your cuda binary is not located at `/usr/local/cuda/bin`, please specify it in function `_find_compiler_bindir_posix()`.*)
+- Line 24 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
+- Line 32 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default.
+- Line 36 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator.
+- Line 33 of file `conv2d_gradfix.py`: Enable customized convolution operators by default.
+- Line 46/51 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators.
+- Line 36/66 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators.
+- Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator.
+
+Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default.
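A sketch of the `ref`/`cuda` switch described in the README, using the `bias_act()` signature from the upstream stylegan2-ada-pytorch repository; treat the exact keyword names as an assumption in case the vendored copy has drifted:

    import torch
    from third_party.stylegan2_official_ops import bias_act

    x = torch.randn(4, 512, 16, 16)
    b = torch.zeros(512)
    y = bias_act.bias_act(x, b, act='lrelu', impl='ref')  # native PyTorch ops
    # On a GPU, impl='cuda' compiles and uses the fused kernel instead:
    # y = bias_act.bias_act(x.cuda(), b.cuda(), act='lrelu', impl='cuda')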
third_party/stylegan2_official_ops/__init__.py
ADDED
File without changes
third_party/stylegan2_official_ops/bias_act.cpp
ADDED
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+    if (x.dim() != y.dim())
+        return false;
+    for (int64_t i = 0; i < x.dim(); i++)
+    {
+        if (x.size(i) != y.size(i))
+            return false;
+        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+            return false;
+    }
+    return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+    // Validate arguments.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+    TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+    // Validate layout.
+    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    torch::Tensor y = torch::empty_like(x);
+    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+    // Initialize CUDA kernel parameters.
+    bias_act_kernel_params p;
+    p.x     = x.data_ptr();
+    p.b     = (b.numel()) ? b.data_ptr() : NULL;
+    p.xref  = (xref.numel()) ? xref.data_ptr() : NULL;
+    p.yref  = (yref.numel()) ? yref.data_ptr() : NULL;
+    p.dy    = (dy.numel()) ? dy.data_ptr() : NULL;
+    p.y     = y.data_ptr();
+    p.grad  = grad;
+    p.act   = act;
+    p.alpha = alpha;
+    p.gain  = gain;
+    p.clamp = clamp;
+    p.sizeX = (int)x.numel();
+    p.sizeB = (int)b.numel();
+    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+    // Choose CUDA kernel.
+    void* kernel;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+    {
+        kernel = choose_bias_act_kernel<scalar_t>(p);
+    });
+    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+    // Launch CUDA kernel.
+    p.loopX = 4;
+    int blockSize = 4 * 32;
+    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+    void* args[] = {&p};
+    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+    return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
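The launch configuration at the end of `bias_act()` above is a ceiling division: each block covers `loopX * blockSize` elements. A quick Python check of the arithmetic:

    def grid_size(size_x, loop_x=4, block_size=4 * 32):
        # Mirrors `(p.sizeX - 1) / (p.loopX * blockSize) + 1` above
        # (integer division in C++).
        return (size_x - 1) // (loop_x * block_size) + 1

    assert grid_size(1) == 1    # at least one block
    assert grid_size(512) == 1  # exactly one full block's worth of work
    assert grid_size(513) == 2  # one extra element spills into a second block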
third_party/stylegan2_official_ops/bias_act.cu
ADDED
@@ -0,0 +1,173 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "bias_act.h"

//------------------------------------------------------------------------
// Helpers.

template <class T> struct InternalType;
template <> struct InternalType<double>    { typedef double scalar_t; };
template <> struct InternalType<float>     { typedef float  scalar_t; };
template <> struct InternalType<c10::Half> { typedef float  scalar_t; };

//------------------------------------------------------------------------
// CUDA kernel.

template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G = p.grad;
    scalar_t alpha = (scalar_t)p.alpha;
    scalar_t gain = (scalar_t)p.gain;
    scalar_t clamp = (scalar_t)p.clamp;
    scalar_t one = (scalar_t)1;
    scalar_t two = (scalar_t)2;
    scalar_t expRange = (scalar_t)80;
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load.
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0;
        scalar_t y = 0;

        // Apply bias.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);

//------------------------------------------------------------------------
third_party/stylegan2_official_ops/bias_act.h
ADDED
@@ -0,0 +1,38 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct bias_act_kernel_params
{
    const void* x;      // [sizeX]
    const void* b;      // [sizeB] or NULL
    const void* xref;   // [sizeX] or NULL
    const void* yref;   // [sizeX] or NULL
    const void* dy;     // [sizeX] or NULL
    void*       y;      // [sizeX]

    int         grad;
    int         act;
    float       alpha;
    float       gain;
    float       clamp;

    int         sizeX;
    int         sizeB;
    int         stepB;
    int         loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);

//------------------------------------------------------------------------
third_party/stylegan2_official_ops/bias_act.py
ADDED
@@ -0,0 +1,227 @@
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom ops to fuse bias and activation as one operator, which is efficient.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=line-too-long
# pylint: disable=missing-class-docstring
# pylint: disable=global-statement
# pylint: disable=bare-except

import os
import warnings
import traceback
from easydict import EasyDict
import numpy as np
import torch

from . import custom_ops
from . import misc

#----------------------------------------------------------------------------

activation_funcs = {
    'linear':   EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
    'relu':     EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu':    EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh':     EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid':  EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu':      EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu':     EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish':    EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
}

#----------------------------------------------------------------------------

_inited = False
_plugin = None
_null_tensor = torch.empty([0])

def _init():
    global _inited, _plugin
    if not _inited:
        _inited = True
        sources = ['bias_act.cpp', 'bias_act.cu']
        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        try:
            _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except:
            warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None

#----------------------------------------------------------------------------

def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x: Input activation tensor. Can be of any shape.
        b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
            as `x`. The shape must be known, and it must match the dimension of `x`
            corresponding to `dim`.
        dim: The dimension in `x` corresponding to the elements of `b`.
            The value of `dim` is ignored if `b` is not specified.
        act: Name of the activation function to evaluate, or `"linear"` to disable.
            Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
            See `activation_funcs` for a full list. `None` is not allowed.
        alpha: Shape parameter for the activation function, or `None` to use the default.
        gain: Scaling factor for the output tensor, or `None` to use default.
            See `activation_funcs` for the default scaling of each activation function.
            If unsure, consider specifying 1.
        clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
            the clamping (default).
        impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

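A minimal usage sketch (assuming the repository root is on `PYTHONPATH`; shapes are hypothetical): fuse "add per-channel bias, apply leaky ReLU with its default sqrt(2) gain" into a single call. With `impl='cuda'` the compiled plugin is used when the input lives on the GPU and the build succeeds; otherwise the call transparently falls back to the reference path below.

import torch
from third_party.stylegan2_official_ops import bias_act

x = torch.randn(4, 512, 32, 32)   # hypothetical [N, C, H, W] feature map
b = torch.zeros(512)              # one bias per channel (dim=1)
y = bias_act.bias_act(x, b, dim=1, act='lrelu', impl='ref')   # reference path
# On a GPU build, the fused plugin handles the same call:
# y = bias_act.bias_act(x.cuda(), b.cuda(), dim=1, act='lrelu', impl='cuda')
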
#----------------------------------------------------------------------------

@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp)
    return x

#----------------------------------------------------------------------------

_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda

#----------------------------------------------------------------------------

# pylint: enable=line-too-long
# pylint: enable=missing-class-docstring
# pylint: enable=global-statement
# pylint: enable=bare-except
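`_bias_act_cuda()` above generates one `autograd.Function` subclass per `(dim, act, alpha, gain, clamp)` key and memoizes it, so the hyper-parameters are baked in via closures and repeated calls cost only a dict lookup. A stripped-down sketch of the same pattern (illustrative, not this module's code):

import torch

_scale_cache = {}

def scaled_identity(gain):
    """Build (or fetch) an autograd.Function with `gain` captured by closure."""
    if gain in _scale_cache:
        return _scale_cache[gain]

    class ScaledIdentity(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x * gain     # hyper-parameter baked in, not passed per call

        @staticmethod
        def backward(ctx, dy):
            return dy * gain

    _scale_cache[gain] = ScaledIdentity
    return ScaledIdentity

x = torch.randn(3, requires_grad=True)
scaled_identity(2.0).apply(x).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))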
third_party/stylegan2_official_ops/conv2d_gradfix.py
ADDED
@@ -0,0 +1,189 @@
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for convolution operators.

Operators in this file support arbitrarily high order gradients with zero
performance penalty. Please set `impl` as `cuda` to use faster customized
operators, OR as `ref` to use native `torch.nn.functional.conv2d` and
`torch.nn.functional.conv_transpose2d`.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access
# pylint: disable=line-too-long
# pylint: disable=global-statement
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

import warnings
import contextlib
import torch

enabled = True                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False  # Forcefully disable computation of gradients with respect to the weights.

@contextlib.contextmanager
def no_weight_gradients():
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old

#----------------------------------------------------------------------------

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, impl='cuda'):
    if impl == 'cuda' and _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
    return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, impl='cuda'):
    if impl == 'cuda' and _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
    return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)

#----------------------------------------------------------------------------

def _should_use_custom_op(input):
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if input.device.type != 'cuda':
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
        return True
    warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
    return False

def _tuple_of_ints(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs

#----------------------------------------------------------------------------

_conv2d_gradfix_cache = dict()

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else: # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            if not transpose:
                output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
            else: # transpose
                output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            ctx.save_for_backward(input, weight)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
                assert grad_input.shape == input.shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
            flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
            grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
            assert grad_weight.shape == weight_shape
            ctx.save_for_backward(grad_output, input)
            return grad_weight

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output.shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input.shape

            return grad2_grad_output, grad2_input

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d

#----------------------------------------------------------------------------

# pylint: enable=redefined-builtin
# pylint: enable=arguments-differ
# pylint: enable=protected-access
# pylint: enable=line-too-long
# pylint: enable=global-statement
# pylint: enable=missing-class-docstring
# pylint: enable=missing-function-docstring
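`no_weight_gradients()` above is how callers (e.g. regularization passes) skip the weight-gradient branch in `Conv2d.backward()` while still getting input gradients. A minimal sketch (CPU tensors, so this particular run takes the native-op fallback rather than the custom op):

import torch
from third_party.stylegan2_official_ops import conv2d_gradfix

x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
with conv2d_gradfix.no_weight_gradients():
    y = conv2d_gradfix.conv2d(x, w, padding=1, impl='cuda')
y.sum().backward()
# On the custom-op (CUDA) path, w.grad would remain None inside the context;
# the CPU fallback goes through torch.nn.functional.conv2d and still fills it.
print(x.grad.shape)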
third_party/stylegan2_official_ops/conv2d_resample.py
ADDED
@@ -0,0 +1,168 @@
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""2D convolution with optional up/downsampling.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=line-too-long

import torch

from . import misc
from . import conv2d_gradfix
from . import upfirdn2d
from .upfirdn2d import _parse_padding
from .upfirdn2d import _get_filter_size

#----------------------------------------------------------------------------

def _get_weight_shape(w):
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        shape = [int(sz) for sz in w.shape]
    misc.assert_shape(w, shape)
    return shape

#----------------------------------------------------------------------------

def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
    """
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)

    # Flip weight if requested.
    if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
        w = w.flip([2, 3])

    # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
    # 1x1 kernel + memory_format=channels_last + less than 64 channels.
    if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
        if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
            if out_channels <= 4 and groups == 1:
                in_shape = x.shape
                x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
                x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
            else:
                x = x.to(memory_format=torch.contiguous_format)
                w = w.to(memory_format=torch.contiguous_format)
                x = conv2d_gradfix.conv2d(x, w, groups=groups, impl=impl)
            return x.to(memory_format=torch.channels_last)

    # Otherwise => execute using conv2d_gradfix.
    op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups, impl=impl)

#----------------------------------------------------------------------------

@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x: Input tensor of shape
            `[batch_size, in_channels, in_height, in_width]`.
        w: Weight tensor of shape
            `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f: Low-pass filter for up/downsampling. Must be prepared beforehand by
            calling upfirdn2d.setup_filter(). None = identity (default).
        up: Integer upsampling factor (default: 1).
        down: Integer downsampling factor (default: 1).
        padding: Padding with respect to the upsampled image. Can be a single number
            or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        groups: Split input channels into N groups (default: 1).
        flip_weight: False = convolution, True = correlation (default: True).
        flip_filter: False = convolution, True = correlation (default: False).
        impl: Implementation mode of customized ops. 'ref' for native PyTorch
            implementation, 'cuda' for `.cu` implementation (default: 'cuda').

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl)
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl)

    # Fallback: Generic reference implementation.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
    return x

#----------------------------------------------------------------------------

# pylint: enable=line-too-long
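The `up > 1` branch above grows the requested padding so the low-pass filter stays centered on the upsampled grid. Worked by hand for a typical configuration (hypothetical numbers: 4-tap filter, 2x upsampling, no user padding), each axis gains 2 pixels before and 1 pixel after:

# The up > 1 padding adjustment from conv2d_resample(), evaluated by hand.
fw = fh = 4                    # filter taps
up = 2
px0 = px1 = py0 = py1 = 0      # user-requested padding
px0 += (fw + up - 1) // 2      # 5 // 2 -> 2
px1 += (fw - up) // 2          # 2 // 2 -> 1
py0 += (fh + up - 1) // 2      # -> 2
py1 += (fh - up) // 2          # -> 1
assert (px0, px1, py0, py1) == (2, 1, 2, 1)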
third_party/stylegan2_official_ops/custom_ops.py
ADDED
@@ -0,0 +1,159 @@
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Utility functions to setup customized operators.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=line-too-long
# pylint: disable=missing-function-docstring
# pylint: disable=useless-suppression
# pylint: disable=inconsistent-quotes

import os
import glob
import importlib
import hashlib
import shutil
from pathlib import Path

import torch
from torch.utils.file_baton import FileBaton
import torch.utils.cpp_extension

#----------------------------------------------------------------------------
# Global options.

verbosity = 'none' # Verbosity level: 'none', 'brief', 'full'

#----------------------------------------------------------------------------
# Internal helper funcs.

def _find_compiler_bindir():
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

def _find_compiler_bindir_posix():
    patterns = [
        '/usr/local/cuda/bin'
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = dict()

def get_plugin(module_name, sources, **build_kwargs):
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        elif os.name == 'posix':
            compiler_bindir = _find_compiler_bindir_posix()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".')
            os.environ['PATH'] += ':' + compiler_bindir  # POSIX uses ':' as the PATH separator.

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = set(os.path.dirname(source) for source in sources)
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                                           verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------

# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring
# pylint: enable=useless-suppression
# pylint: enable=inconsistent-quotes
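Callers normally touch only `verbosity` and `get_plugin()`. A usage sketch mirroring how `bias_act.py` builds its plugin (a working CUDA toolchain is assumed; the first call compiles, later calls hit `_cached_plugins`):

import os
from third_party.stylegan2_official_ops import custom_ops

custom_ops.verbosity = 'brief'   # print a one-line build status
sources = [os.path.join(os.path.dirname(custom_ops.__file__), s)
           for s in ('bias_act.cpp', 'bias_act.cu')]
plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources,
                               extra_cuda_cflags=['--use_fast_math'])
# plugin.bias_act(...) is now the entry point exposed by PYBIND11_MODULE above.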
third_party/stylegan2_official_ops/fma.py
ADDED
@@ -0,0 +1,73 @@
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=line-too-long
# pylint: disable=missing-function-docstring

import torch

#----------------------------------------------------------------------------

def fma(a, b, c, impl='cuda'): # => a * b + c
    if impl == 'cuda':
        return _FusedMultiplyAdd.apply(a, b, c)
    return torch.addcmul(c, a, b)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        x = x.reshape(-1, *x.shape[extra_dims+1:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------

# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring
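A quick sketch showing that both paths compute `a * b + c`, and that `_unbroadcast()` folds the gradient of a broadcast operand back to its original shape (this runs on CPU too, since the 'cuda' path here is plain autograd around `torch.addcmul`):

import torch
from third_party.stylegan2_official_ops import fma

a = torch.randn(8, 4, requires_grad=True)
b = torch.randn(8, 4)
c = torch.randn(4, requires_grad=True)       # broadcast along dim 0
out = fma.fma(a, b, c, impl='cuda')
assert torch.allclose(out, a * b + c)
out.sum().backward()
assert c.grad.shape == c.shape               # _unbroadcast summed over dim 0
assert torch.allclose(c.grad, torch.full_like(c, 8.0))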
third_party/stylegan2_official_ops/grid_sample_gradfix.py
ADDED
@@ -0,0 +1,98 @@
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.grid_sample`.

This is useful for differentiable augmentation. This customized operator
supports arbitrarily high order gradients between the input and output. Only
works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and
`align_corners=False`.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access
# pylint: disable=line-too-long
# pylint: disable=missing-function-docstring

import warnings
import torch

#----------------------------------------------------------------------------

enabled = True  # Enable the custom op by setting this to true.

#----------------------------------------------------------------------------

def grid_sample(input, grid, impl='cuda'):
    if impl == 'cuda' and _should_use_custom_op():
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    if not enabled:
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
        return True
    warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
    return False

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
        return grad_input, grad_grid

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------

# pylint: enable=redefined-builtin
# pylint: enable=arguments-differ
# pylint: enable=protected-access
# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring
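A minimal sanity sketch: sampling with an identity grid (built via `torch.nn.functional.affine_grid` with `align_corners=False`, matching the op's convention) reproduces the input. On PyTorch versions outside 1.7-1.9 the call warns and falls back to the native op, with the same result:

import torch
from third_party.stylegan2_official_ops import grid_sample_gradfix

x = torch.randn(2, 3, 16, 16)
theta = torch.eye(2, 3).unsqueeze(0).repeat(2, 1, 1)   # identity affine transform
grid = torch.nn.functional.affine_grid(theta, list(x.shape), align_corners=False)
y = grid_sample_gradfix.grid_sample(x, grid)
assert torch.allclose(y, x, atol=1e-6)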
third_party/stylegan2_official_ops/misc.py
ADDED
@@ -0,0 +1,281 @@
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Misc functions for customized operations.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=line-too-long
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
# pylint: disable=use-maxsplit-arg
# pylint: disable=unnecessary-comprehension

import re
import contextlib
import warnings
from easydict import EasyDict
import numpy as np
import torch

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().

class suppress_tracer_warnings(warnings.catch_warnings):
    def __enter__(self):
        super().__enter__()
        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
        return self

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    decorator.__name__ = fn.__name__
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
                if window >= 2:
                    j = (i - rnd.randint(window)) % order.size
                    order[i], order[j] = order[j], order[i]
            idx += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.

def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())

def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())

def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)

#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

#----------------------------------------------------------------------------
# Check DistributedDataParallel consistency across processes.

def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname

#----------------------------------------------------------------------------
# Print summary table of module hierarchy.

def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs

#----------------------------------------------------------------------------

# pylint: enable=line-too-long
# pylint: enable=missing-class-docstring
# pylint: enable=missing-function-docstring
# pylint: enable=use-maxsplit-arg
# pylint: enable=unnecessary-comprehension
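
A short sketch of the two helpers above that are easiest to misread, `constant()` (cached constant tensors) and `assert_shape()` (shape checks with wildcard dimensions). The import path is an assumption based on this diff:

import torch
from third_party.stylegan2_official_ops import misc

# constant(): a second call with identical arguments returns the cached
# tensor object, which is the point -- no repeated host-to-device copy.
a = misc.constant([1.0, 2.0])
b = misc.constant([1.0, 2.0])
assert a is b

# assert_shape(): None leaves a dimension unconstrained.
x = torch.zeros(4, 3, 256, 256)
misc.assert_shape(x, [None, 3, 256, 256])    # passes
# misc.assert_shape(x, [None, 1, 256, 256])  # would raise AssertionError
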
third_party/stylegan2_official_ops/upfirdn2d.cpp
ADDED
@@ -0,0 +1,103 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------

static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");

    // Initialize CUDA kernel parameters.
    upfirdn2d_kernel_params p;
    p.x = x.data_ptr();
    p.f = f.data_ptr<float>();
    p.y = y.data_ptr();
    p.up = make_int2(upx, upy);
    p.down = make_int2(downx, downy);
    p.pad0 = make_int2(padx0, pady0);
    p.flip = (flip) ? 1 : 0;
    p.gain = gain;
    p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
    p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
    p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
    p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
    p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
    p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;

    // Choose CUDA kernel.
    upfirdn2d_kernel_spec spec;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        spec = choose_upfirdn2d_kernel<scalar_t>(p);
    });

    // Set looping options.
    p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
    p.loopMinor = spec.loopMinor;
    p.loopX = spec.loopX;
    p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
    p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;

    // Compute grid size.
    dim3 blockSize, gridSize;
    if (spec.tileOutW < 0) // large
    {
        blockSize = dim3(4, 32, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
            p.launchMajor);
    }
    else // small
    {
        blockSize = dim3(256, 1, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
            p.launchMajor);
    }

    // Launch CUDA kernel.
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}

//------------------------------------------------------------------------
third_party/stylegan2_official_ops/upfirdn2d.cu
ADDED
@@ -0,0 +1,350 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------
// Helpers.

template <class T> struct InternalType;
template <> struct InternalType<double> { typedef double scalar_t; };
template <> struct InternalType<float> { typedef float scalar_t; };
template <> struct InternalType<c10::Half> { typedef float scalar_t; };

static __device__ __forceinline__ int floor_div(int a, int b)
{
    int t = 1 - a / b;
    return (a + t * b) / b - t;
}

//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.

template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;

    // Calculate thread index.
    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorBase / p.launchMinor;
    minorBase -= outY * p.launchMinor;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Setup Y receptive field.
    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
    if (p.flip)
        filterY = p.filterSize.y - 1 - filterY;

    // Loop over major, minor, and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
    {
        int nc = major * p.sizeMinor + minor;
        int n = nc / p.inSize.z;
        int c = nc - n * p.inSize.z;
        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
        {
            // Setup X receptive field.
            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
            if (p.flip)
                filterX = p.filterSize.x - 1 - filterX;

            // Initialize pointers.
            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
            int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
            int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;

            // Inner loop.
            scalar_t v = 0;
            for (int y = 0; y < h; y++)
            {
                for (int x = 0; x < w; x++)
                {
                    v += (scalar_t)(*xp) * (scalar_t)(*fp);
                    xp += p.inStride.x;
                    fp += filterStepX;
                }
                xp += p.inStride.y - w * p.inStride.x;
                fp += filterStepY - w * filterStepX;
            }

            // Store result.
            v *= p.gain;
            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
        }
    }
}

//------------------------------------------------------------------------
// Specialized CUDA implementation for small filters.

template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
    __shared__ volatile scalar_t sf[filterH][filterW];
    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];

    // Calculate tile index.
    int minorBase = blockIdx.x;
    int tileOutY = minorBase / p.launchMinor;
    minorBase -= tileOutY * p.launchMinor;
    minorBase *= loopMinor;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Load filter (flipped).
    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
    {
        int fy = tapIdx / filterW;
        int fx = tapIdx - fy * filterW;
        scalar_t v = 0;
        if (fx < p.filterSize.x & fy < p.filterSize.y)
        {
            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
        }
        sf[fy][fx] = v;
    }

    // Loop over major and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    {
        int baseNC = major * p.sizeMinor + minorBase;
        int n = baseNC / p.inSize.z;
        int baseC = baseNC - n * p.inSize.z;
        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
        {
            // Load input pixels.
            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
            int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
            int tileInX = floor_div(tileMidX, upx);
            int tileInY = floor_div(tileMidY, upy);
            __syncthreads();
            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
            {
                int relC = inIdx;
                int relInX = relC / loopMinor;
                int relInY = relInX / tileInW;
                relC -= relInX * loopMinor;
                relInX -= relInY * tileInW;
                int c = baseC + relC;
                int inX = tileInX + relInX;
                int inY = tileInY + relInY;
                scalar_t v = 0;
                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
                sx[relInY][relInX][relC] = v;
            }

            // Loop over output pixels.
            __syncthreads();
            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
            {
                int relC = outIdx;
                int relOutX = relC / loopMinor;
                int relOutY = relOutX / tileOutW;
                relC -= relOutX * loopMinor;
                relOutX -= relOutY * tileOutW;
                int c = baseC + relC;
                int outX = tileOutX + relOutX;
                int outY = tileOutY + relOutY;

                // Setup receptive field.
                int midX = tileMidX + relOutX * downx;
                int midY = tileMidY + relOutY * downy;
                int inX = floor_div(midX, upx);
                int inY = floor_div(midY, upy);
                int relInX = inX - tileInX;
                int relInY = inY - tileInY;
                int filterX = (inX + 1) * upx - midX - 1; // flipped
                int filterY = (inY + 1) * upy - midY - 1; // flipped

                // Inner loop.
                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
                {
                    scalar_t v = 0;
                    #pragma unroll
                    for (int y = 0; y < filterH / upy; y++)
                        #pragma unroll
                        for (int x = 0; x < filterW / upx; x++)
                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
                    v *= p.gain;
                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
                }
            }
        }
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;

    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last

    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
        if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
        if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
    }
    if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
    }
    if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
    }
    if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
    }
    if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
    }
    return spec;
}

//------------------------------------------------------------------------
// Template specializations.

template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------
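
One subtlety in the kernels above is `floor_div()`: C integer division truncates toward zero, but mapping output pixels back to input coordinates needs division that rounds toward negative infinity once padding makes indices negative. A Python sketch of the same arithmetic, assuming a positive divisor as in the kernels (Python's `//` already floors, so it serves as the reference):

def c_div(a, b):
    # C-style truncating integer division (rounds toward zero).
    q = abs(a) // abs(b)
    return q if (a < 0) == (b < 0) else -q

def floor_div(a, b):
    # Mirrors the CUDA helper above: shift a into positive range, divide, shift back.
    t = 1 - c_div(a, b)
    return c_div(a + t * b, b) - t

for a in range(-8, 9):
    assert floor_div(a, 2) == a // 2  # matches true floor division
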
third_party/stylegan2_official_ops/upfirdn2d.h
ADDED
@@ -0,0 +1,59 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <cuda_runtime.h>

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct upfirdn2d_kernel_params
{
    const void* x;
    const float* f;
    void* y;

    int2 up;
    int2 down;
    int2 pad0;
    int flip;
    float gain;

    int4 inSize; // [width, height, channel, batch]
    int4 inStride;
    int2 filterSize; // [width, height]
    int2 filterStride;
    int4 outSize; // [width, height, channel, batch]
    int4 outStride;
    int sizeMinor;
    int sizeMajor;

    int loopMinor;
    int loopMajor;
    int loopX;
    int launchMinor;
    int launchMajor;
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

struct upfirdn2d_kernel_spec
{
    void* kernel;
    int tileOutW;
    int tileOutH;
    int loopMinor;
    int loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------
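
Before the Python wrapper that follows, it helps to see the four-step `upfirdn` pipeline in isolation. A 1-D NumPy sketch of upsample -> pad -> FIR filter -> downsample (illustration only; the real op below operates on 4-D image batches and fuses these steps in CUDA):

import numpy as np

def upfirdn1d(x, f, up=1, down=1, pad=(0, 0)):
    y = np.zeros(len(x) * up, dtype=x.dtype)
    y[::up] = x                          # 1. insert up-1 zeros after each sample
    y = np.pad(y, pad)                   # 2. zero-pad each side (negative = crop)
    y = np.convolve(y, f, mode='valid')  # 3. FIR filter within the padded footprint
    return y[::down]                     # 4. keep every down-th sample

x = np.array([1.0, 2.0, 3.0, 4.0])
f = np.array([0.5, 1.0, 0.5])             # 3-tap smoothing filter
print(upfirdn1d(x, f, up=2, pad=(2, 1)))  # 9 samples: (4*2 + 2+1 - 3 + 1) / 1
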
third_party/stylegan2_official_ops/upfirdn2d.py
ADDED
@@ -0,0 +1,401 @@
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom operators for efficient resampling of 2D images.

`upfirdn` means executing upsampling, FIR filtering, downsampling in sequence.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=line-too-long
# pylint: disable=missing-class-docstring
# pylint: disable=global-variable-not-assigned
# pylint: disable=bare-except

import os
import warnings
import traceback
import numpy as np
import torch

from . import custom_ops
from . import misc
from . import conv2d_gradfix

#----------------------------------------------------------------------------

_inited = False
_plugin = None

def _init():
    global _inited, _plugin
    if not _inited:
        sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        try:
            _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except:
            warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None

def _parse_scaling(scaling):
    if isinstance(scaling, int):
        scaling = [scaling, scaling]
    assert isinstance(scaling, (list, tuple))
    assert all(isinstance(x, int) for x in scaling)
    sx, sy = scaling
    assert sx >= 1 and sy >= 1
    return sx, sy

def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, int) for x in padding)
    if len(padding) == 2:
        padx, pady = padding
        padding = [padx, padx, pady, pady]
    padx0, padx1, pady0, pady1 = padding
    return padx0, padx1, pady0, pady1

def _get_filter_size(f):
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    fw = f.shape[-1]
    fh = f.shape[0]
    with misc.suppress_tracer_warnings():
        fw = int(fw)
        fh = int(fh)
    misc.assert_shape(f, [fh, fw][:f.ndim])
    assert fw >= 1 and fh >= 1
    return fw, fh

#----------------------------------------------------------------------------

def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f: Torch tensor, numpy array, or python list of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable),
            `[]` (impulse), or
            `None` (identity).
        device: Result device (default: cpu).
        normalize: Normalize the filter so that it retains the magnitude
            for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        separable: Return a separable filter? (default: select automatically).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable).
    """
    # Validate.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable?
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        f /= f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f

#----------------------------------------------------------------------------

def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Pad, upsample, filter, and downsample a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    2. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by keeping every Nth pixel (`down`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x: Float32/float64/float16 input tensor of the shape
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), or
            `None` (identity).
        up: Integer upsampling factor. Can be a single int or a list/tuple
            `[x, y]` (default: 1).
        down: Integer downsampling factor. Can be a single int or a list/tuple
            `[x, y]` (default: 1).
        padding: Padding with respect to the upsampled image. Can be a single number
            or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)

#----------------------------------------------------------------------------

@misc.profiled_function
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Upsample by inserting zeros.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop.
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter.
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
    else:
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    return x

#----------------------------------------------------------------------------

_upfirdn2d_cuda_cache = dict()

def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Fast CUDA implementation of `upfirdn2d()` using custom ops.
    """
    # Parse arguments.
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Lookup from cache.
    key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
    if key in _upfirdn2d_cuda_cache:
        return _upfirdn2d_cuda_cache[key]

    # Forward op.
    class Upfirdn2dCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, f): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4
            if f is None:
                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
            y = x
            if f.ndim == 2:
                y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
            else:
                y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
                y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
            ctx.save_for_backward(f)
            ctx.x_shape = x.shape
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            f, = ctx.saved_tensors
            _, _, ih, iw = ctx.x_shape
            _, _, oh, ow = dy.shape
            fw, fh = _get_filter_size(f)
            p = [
                fw - padx0 - 1,
                iw * upx - ow * downx + padx0 - upx + 1,
                fh - pady0 - 1,
                ih * upy - oh * downy + pady0 - upy + 1,
            ]
            dx = None
            df = None

            if ctx.needs_input_grad[0]:
                dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)

            assert not ctx.needs_input_grad[1]
            return dx, df

    # Add to cache.
    _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
    return Upfirdn2dCuda

#----------------------------------------------------------------------------

def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Filter a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape matches the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x: Float32/float64/float16 input tensor of the shape
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter of the shape
            `[filter_height, filter_width]` (non-separable),
            `[filter_taps]` (separable), or
            `None` (identity).
        padding: Padding with respect to the output. Can be a single number or a
            list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain: Overall scaling factor for signal magnitude (default: 1).
        impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + fw // 2,
        padx1 + (fw - 1) // 2,
        pady0 + fh // 2,
        pady1 + (fh - 1) // 2,
    ]
    return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)

#----------------------------------------------------------------------------

def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a multiple of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x: Float32/float64/float16 input tensor of the shape
            `[batch_size, num_channels, in_height, in_width]`.
|
330 |
+
f: Float32 FIR filter of the shape
|
331 |
+
`[filter_height, filter_width]` (non-separable),
|
332 |
+
`[filter_taps]` (separable), or
|
333 |
+
`None` (identity).
|
334 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
335 |
+
`[x, y]` (default: 1).
|
336 |
+
padding: Padding with respect to the output. Can be a single number or a
|
337 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
338 |
+
(default: 0).
|
339 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
340 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
341 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
342 |
+
|
343 |
+
Returns:
|
344 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
345 |
+
"""
|
346 |
+
upx, upy = _parse_scaling(up)
|
347 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
348 |
+
fw, fh = _get_filter_size(f)
|
349 |
+
p = [
|
350 |
+
padx0 + (fw + upx - 1) // 2,
|
351 |
+
padx1 + (fw - upx) // 2,
|
352 |
+
pady0 + (fh + upy - 1) // 2,
|
353 |
+
pady1 + (fh - upy) // 2,
|
354 |
+
]
|
355 |
+
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
356 |
+
|
357 |
+
#----------------------------------------------------------------------------
|
358 |
+
|
359 |
+
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
360 |
+
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
361 |
+
|
362 |
+
By default, the result is padded so that its shape is a fraction of the input.
|
363 |
+
User-specified padding is applied on top of that, with negative values
|
364 |
+
indicating cropping. Pixels outside the image are assumed to be zero.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
x: Float32/float64/float16 input tensor of the shape
|
368 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
369 |
+
f: Float32 FIR filter of the shape
|
370 |
+
`[filter_height, filter_width]` (non-separable),
|
371 |
+
`[filter_taps]` (separable), or
|
372 |
+
`None` (identity).
|
373 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
374 |
+
`[x, y]` (default: 1).
|
375 |
+
padding: Padding with respect to the input. Can be a single number or a
|
376 |
+
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
377 |
+
(default: 0).
|
378 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
379 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
380 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
381 |
+
|
382 |
+
Returns:
|
383 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
384 |
+
"""
|
385 |
+
downx, downy = _parse_scaling(down)
|
386 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
387 |
+
fw, fh = _get_filter_size(f)
|
388 |
+
p = [
|
389 |
+
padx0 + (fw - downx + 1) // 2,
|
390 |
+
padx1 + (fw - downx) // 2,
|
391 |
+
pady0 + (fh - downy + 1) // 2,
|
392 |
+
pady1 + (fh - downy) // 2,
|
393 |
+
]
|
394 |
+
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
395 |
+
|
396 |
+
#----------------------------------------------------------------------------
|
397 |
+
|
398 |
+
# pylint: enable=line-too-long
|
399 |
+
# pylint: enable=missing-class-docstring
|
400 |
+
# pylint: enable=global-variable-not-assigned
|
401 |
+
# pylint: enable=bare-except
|
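Taken together, `filter2d()`, `upsample2d()`, and `downsample2d()` are thin wrappers that reduce to a single `upfirdn2d()` call with padding and gain chosen per mode. Below is a minimal usage sketch; the filter taps are illustrative, the import assumes the repository root is on `PYTHONPATH`, and `setup_filter()` is defined earlier in this file:

```python
import torch

from third_party.stylegan2_official_ops import upfirdn2d

x = torch.randn(2, 3, 64, 64)             # [batch, channels, height, width]
f = upfirdn2d.setup_filter([1, 3, 3, 1])  # 4-tap binomial filter, normalized by setup_filter()

y = upfirdn2d.upsample2d(x, f, up=2, impl='ref')      # -> [2, 3, 128, 128]
z = upfirdn2d.downsample2d(y, f, down=2, impl='ref')  # -> [2, 3, 64, 64]
w = upfirdn2d.filter2d(x, f, impl='ref')              # shape preserved: [2, 3, 64, 64]
```

With `impl='ref'` the calls go through `_upfirdn2d_ref()` above, so this sketch also runs on CPU; `impl='cuda'` routes to the compiled plugin.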
third_party/stylegan3_official_ops/README.md
ADDED
@@ -0,0 +1,30 @@
# Operators for StyleGAN3

All files in this directory are borrowed from the [stylegan3](https://github.com/NVlabs/stylegan3) repository. They implement customized operators that are faster than the native PyTorch operators, especially for second-derivative computation, including

- `bias_act.bias_act()`: Fuses adding bias and performing activation into one operator.
- `upfirdn2d.setup_filter()`: Sets up the kernel used for filtering.
- `upfirdn2d.filter2d()`: Filters a 2D feature map with a given kernel.
- `upfirdn2d.upsample2d()`: Upsamples a 2D feature map.
- `upfirdn2d.downsample2d()`: Downsamples a 2D feature map.
- `upfirdn2d.upfirdn2d()`: Upsamples, filters, and then downsamples a 2D feature map.
- `filtered_lrelu.filtered_lrelu()`: Leaky ReLU layer, wrapped with upsampling and downsampling for anti-aliasing.
- `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high-order gradients and fixing the gradient when computing the penalty.
- `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high-order gradients and fixing the gradient when computing the penalty.
- `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`).

We make the following slight modifications, beyond disabling some lint warnings:

- Line 24 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` of [stylegan3](https://github.com/NVlabs/stylegan3).
- Line 36 of file `custom_ops.py`: Disable the log message when setting up customized operators.
- Line 54/109 of file `custom_ops.py`: Add the necessary CUDA compiler path. (***NOTE**: If your CUDA binary is not located at `/usr/local/cuda/bin`, please specify it in the function `_find_compiler_bindir_posix()`.*)
- Line 21 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` of [stylegan3](https://github.com/NVlabs/stylegan3).
- Lines 162-165 of file `filtered_lrelu.py`: Change some implementations in `_filtered_lrelu_ref()` to the `ref` mode.
- Line 31 of file `grid_sample_gradfix.py`: Enable the customized grid-sampling operator by default.
- Line 35 of file `grid_sample_gradfix.py`: Use `impl` to disable the customized grid-sampling operator.
- Line 34 of file `conv2d_gradfix.py`: Enable the customized convolution operators by default.
- Line 48/53 of file `conv2d_gradfix.py`: Use `impl` to disable the customized convolution operators.
- Line 36/53 of file `conv2d_resample.py`: Use `impl` to disable the customized convolution operators.
- Line 23 of file `fma.py`: Use `impl` to disable the customized add-multiply operator.

Please use `ref` or `cuda` to choose which implementation to use: `ref` refers to the native PyTorch operators, while `cuda` refers to the customized operators from the official repository. `cuda` is used by default, as in the sketch below.
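A minimal sketch of switching implementations with `bias_act()`; the tensor shapes are illustrative, and the import assumes the repository root is on `PYTHONPATH`:

```python
import torch

from third_party.stylegan3_official_ops import bias_act

x = torch.randn(4, 8)  # [batch, channels]
b = torch.randn(8)     # per-channel bias, added along dim=1 by default

# Native PyTorch fallback, e.g. for CPU-only debugging.
y_ref = bias_act.bias_act(x, b, act='lrelu', impl='ref')

# Fused custom kernel (the default); requires a CUDA device and toolchain.
if torch.cuda.is_available():
    y_cuda = bias_act.bias_act(x.cuda(), b.cuda(), act='lrelu', impl='cuda')
```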