Spaces:
Runtime error
Runtime error
Mehdi Cherti
committed on
Commit
·
3dcdf92
1
Parent(s):
e96a195
- memory efficient EMA
Browse files
- fix gradient checkpointing
- evaluate using image reward paper
- use attn_resolution in test
- EMA.py +13 -4
- run.py +29 -1
- score_sde/models/ncsnpp_generator_adagn.py +18 -10
- test_ddgan.py +29 -14
- train_ddgan.py +37 -53
EMA.py
CHANGED
@@ -15,13 +15,14 @@ from torch.optim import Optimizer
|
|
15 |
|
16 |
|
17 |
class EMA(Optimizer):
|
18 |
-
def __init__(self, opt, ema_decay):
|
19 |
self.ema_decay = ema_decay
|
20 |
self.apply_ema = self.ema_decay > 0.
|
21 |
self.optimizer = opt
|
22 |
self.state = opt.state
|
23 |
self.param_groups = opt.param_groups
|
24 |
self.defaults = {}
|
|
|
25 |
|
26 |
def step(self, *args, **kwargs):
|
27 |
# for group in self.optimizer.param_groups:
|
@@ -53,11 +54,19 @@ class EMA(Optimizer):
|
|
53 |
|
54 |
params[p.shape]['data'].append(p.data)
|
55 |
ema[p.shape].append(state['ema'])
|
|
|
|
|
|
|
56 |
|
57 |
for i in params:
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
for p in group['params']:
|
63 |
if p.grad is None:
|
|
|
15 |
|
16 |
|
17 |
class EMA(Optimizer):
|
18 |
+
def __init__(self, opt, ema_decay, memory_efficient=False):
|
19 |
self.ema_decay = ema_decay
|
20 |
self.apply_ema = self.ema_decay > 0.
|
21 |
self.optimizer = opt
|
22 |
self.state = opt.state
|
23 |
self.param_groups = opt.param_groups
|
24 |
self.defaults = {}
|
25 |
+
self.memory_efficient = memory_efficient
|
26 |
|
27 |
def step(self, *args, **kwargs):
|
28 |
# for group in self.optimizer.param_groups:
|
|
|
54 |
|
55 |
params[p.shape]['data'].append(p.data)
|
56 |
ema[p.shape].append(state['ema'])
|
57 |
+
|
58 |
+
# def stack(d, dim=0):
|
59 |
+
# return torch.stack([di.cpu() for di in d], dim=dim).cuda()
|
60 |
|
61 |
for i in params:
|
62 |
+
if self.memory_efficient:
|
63 |
+
for j in range(len(params[i]['data'])):
|
64 |
+
ema[i][j].mul_(self.ema_decay).add_(params[i]['data'][j], alpha=1. - self.ema_decay)
|
65 |
+
ema[i] = torch.stack(ema[i], dim=0)
|
66 |
+
else:
|
67 |
+
params[i]['data'] = torch.stack(params[i]['data'], dim=0)
|
68 |
+
ema[i] = torch.stack(ema[i], dim=0)
|
69 |
+
ema[i].mul_(self.ema_decay).add_(params[i]['data'], alpha=1. - self.ema_decay)
|
70 |
|
71 |
for p in group['params']:
|
72 |
if p.grad is None:
|
run.py
CHANGED
@@ -274,10 +274,30 @@ def ddgan_ddb_v7():
|
|
274 |
cfg = ddgan_ddb_v1()
|
275 |
return cfg
|
276 |
|
|
|
|
|
|
|
|
|
|
|
277 |
def ddgan_laion_aesthetic_v15():
|
278 |
cfg = ddgan_ddb_v3()
|
279 |
return cfg
|
280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
models = [
|
282 |
ddgan_cifar10_cond17, # cifar10, cross attn for discr
|
283 |
ddgan_cifar10_cond18, # cifar10, xl encoder
|
@@ -326,6 +346,10 @@ models = [
|
|
326 |
ddgan_ddb_v5,
|
327 |
ddgan_ddb_v6,
|
328 |
ddgan_ddb_v7,
|
|
|
|
|
|
|
|
|
329 |
]
|
330 |
|
331 |
def get_model(model_name):
|
@@ -334,7 +358,7 @@ def get_model(model_name):
|
|
334 |
return model()
|
335 |
|
336 |
|
337 |
-
def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional"):
|
338 |
|
339 |
cfg = get_model(model_name)
|
340 |
model = cfg['model']
|
@@ -365,12 +389,16 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guida
|
|
365 |
args['scale_factor_w'] = scale_factor_w
|
366 |
args['n_mlp'] = model.get("n_mlp")
|
367 |
args['scale_method'] = scale_method
|
|
|
368 |
if fid:
|
369 |
args['compute_fid'] = ''
|
370 |
args['real_img_dir'] = real_img_dir
|
371 |
args['nb_images_for_fid'] = nb_images_for_fid
|
372 |
if compute_clip_score:
|
373 |
args['compute_clip_score'] = ""
|
|
|
|
|
|
|
374 |
if eval_name:
|
375 |
args["eval_name"] = eval_name
|
376 |
cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
|
|
|
274 |
cfg = ddgan_ddb_v1()
|
275 |
return cfg
|
276 |
|
277 |
+
def ddgan_ddb_v9():
|
278 |
+
cfg = ddgan_ddb_v3()
|
279 |
+
cfg['model']['attn_resolutions'] = '4 8 16 32'
|
280 |
+
return cfg
|
281 |
+
|
282 |
def ddgan_laion_aesthetic_v15():
|
283 |
cfg = ddgan_ddb_v3()
|
284 |
return cfg
|
285 |
|
286 |
+
def ddgan_ddb_v10():
|
287 |
+
cfg = ddgan_ddb_v9()
|
288 |
+
return cfg
|
289 |
+
|
290 |
+
def ddgan_ddb_v11():
|
291 |
+
cfg = ddgan_ddb_v3()
|
292 |
+
cfg['model']['text_encoder'] = "openclip/ViT-g-14/laion2B-s12B-b42K"
|
293 |
+
return cfg
|
294 |
+
|
295 |
+
def ddgan_ddb_v12():
|
296 |
+
cfg = ddgan_ddb_v3()
|
297 |
+
cfg['model']['text_encoder'] = "openclip/ViT-bigG-14/laion2b_s39b_b160k"
|
298 |
+
return cfg
|
299 |
+
|
300 |
+
|
301 |
models = [
|
302 |
ddgan_cifar10_cond17, # cifar10, cross attn for discr
|
303 |
ddgan_cifar10_cond18, # cifar10, xl encoder
|
|
|
346 |
ddgan_ddb_v5,
|
347 |
ddgan_ddb_v6,
|
348 |
ddgan_ddb_v7,
|
349 |
+
ddgan_ddb_v9,
|
350 |
+
ddgan_ddb_v10,
|
351 |
+
ddgan_ddb_v11,
|
352 |
+
ddgan_ddb_v12,
|
353 |
]
|
354 |
|
355 |
def get_model(model_name):
|
|
|
358 |
return model()
|
359 |
|
360 |
|
361 |
+
def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional", compute_image_reward=False):
|
362 |
|
363 |
cfg = get_model(model_name)
|
364 |
model = cfg['model']
|
|
|
389 |
args['scale_factor_w'] = scale_factor_w
|
390 |
args['n_mlp'] = model.get("n_mlp")
|
391 |
args['scale_method'] = scale_method
|
392 |
+
args['attn_resolutions'] = model.get("attn_resolutions", "16")
|
393 |
if fid:
|
394 |
args['compute_fid'] = ''
|
395 |
args['real_img_dir'] = real_img_dir
|
396 |
args['nb_images_for_fid'] = nb_images_for_fid
|
397 |
if compute_clip_score:
|
398 |
args['compute_clip_score'] = ""
|
399 |
+
|
400 |
+
if compute_image_reward:
|
401 |
+
args['compute_image_reward'] = ""
|
402 |
if eval_name:
|
403 |
args["eval_name"] = eval_name
|
404 |
cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
|
score_sde/models/ncsnpp_generator_adagn.py
CHANGED
@@ -37,6 +37,11 @@ import functools
|
|
37 |
import torch
|
38 |
import numpy as np
|
39 |
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp_Adagn
|
42 |
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp_Adagn
|
@@ -63,6 +68,7 @@ class NCSNpp(nn.Module):
|
|
63 |
def __init__(self, config):
|
64 |
super().__init__()
|
65 |
self.config = config
|
|
|
66 |
self.not_use_tanh = config.not_use_tanh
|
67 |
self.act = act = nn.SiLU()
|
68 |
self.z_emb_dim = z_emb_dim = config.z_emb_dim
|
@@ -176,6 +182,8 @@ class NCSNpp(nn.Module):
|
|
176 |
raise ValueError(f'resblock type {resblock_type} unrecognized.')
|
177 |
|
178 |
# Downsampling block
|
|
|
|
|
179 |
|
180 |
channels = config.num_channels
|
181 |
if progressive_input != 'none':
|
@@ -189,18 +197,18 @@ class NCSNpp(nn.Module):
|
|
189 |
# Residual blocks for this resolution
|
190 |
for i_block in range(num_res_blocks):
|
191 |
out_ch = nf * ch_mult[i_level]
|
192 |
-
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
193 |
in_ch = out_ch
|
194 |
|
195 |
if all_resolutions[i_level] in attn_resolutions:
|
196 |
-
modules.append(AttnBlock(channels=in_ch))
|
197 |
hs_c.append(in_ch)
|
198 |
|
199 |
if i_level != num_resolutions - 1:
|
200 |
if resblock_type == 'ddpm':
|
201 |
modules.append(Downsample(in_ch=in_ch))
|
202 |
else:
|
203 |
-
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
204 |
|
205 |
if progressive_input == 'input_skip':
|
206 |
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
@@ -214,21 +222,21 @@ class NCSNpp(nn.Module):
|
|
214 |
hs_c.append(in_ch)
|
215 |
|
216 |
in_ch = hs_c[-1]
|
217 |
-
modules.append(ResnetBlock(in_ch=in_ch))
|
218 |
-
modules.append(AttnBlock(channels=in_ch))
|
219 |
-
modules.append(ResnetBlock(in_ch=in_ch))
|
220 |
|
221 |
pyramid_ch = 0
|
222 |
# Upsampling block
|
223 |
for i_level in reversed(range(num_resolutions)):
|
224 |
for i_block in range(num_res_blocks + 1):
|
225 |
out_ch = nf * ch_mult[i_level]
|
226 |
-
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(),
|
227 |
-
out_ch=out_ch))
|
228 |
in_ch = out_ch
|
229 |
|
230 |
if all_resolutions[i_level] in attn_resolutions:
|
231 |
-
modules.append(AttnBlock(channels=in_ch))
|
232 |
|
233 |
if progressive != 'none':
|
234 |
if i_level == num_resolutions - 1:
|
@@ -260,7 +268,7 @@ class NCSNpp(nn.Module):
|
|
260 |
if resblock_type == 'ddpm':
|
261 |
modules.append(Upsample(in_ch=in_ch))
|
262 |
else:
|
263 |
-
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
264 |
|
265 |
assert not hs_c
|
266 |
|
|
|
37 |
import torch
|
38 |
import numpy as np
|
39 |
|
40 |
+
try:
|
41 |
+
from fairscale.nn.checkpoint import checkpoint_wrapper
|
42 |
+
except Exception:
|
43 |
+
checkpoint_wrapper = lambda x:x
|
44 |
+
|
45 |
|
46 |
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp_Adagn
|
47 |
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp_Adagn
|
|
|
68 |
def __init__(self, config):
|
69 |
super().__init__()
|
70 |
self.config = config
|
71 |
+
self.grad_checkpointing = config.grad_checkpointing if hasattr(config, "grad_checkpointing") else False
|
72 |
self.not_use_tanh = config.not_use_tanh
|
73 |
self.act = act = nn.SiLU()
|
74 |
self.z_emb_dim = z_emb_dim = config.z_emb_dim
|
|
|
182 |
raise ValueError(f'resblock type {resblock_type} unrecognized.')
|
183 |
|
184 |
# Downsampling block
|
185 |
+
def wrap(block):
|
186 |
+
return checkpoint_wrapper(block) if self.grad_checkpointing else block
|
187 |
|
188 |
channels = config.num_channels
|
189 |
if progressive_input != 'none':
|
|
|
197 |
# Residual blocks for this resolution
|
198 |
for i_block in range(num_res_blocks):
|
199 |
out_ch = nf * ch_mult[i_level]
|
200 |
+
modules.append(wrap(ResnetBlock(in_ch=in_ch, out_ch=out_ch)))
|
201 |
in_ch = out_ch
|
202 |
|
203 |
if all_resolutions[i_level] in attn_resolutions:
|
204 |
+
modules.append(wrap(AttnBlock(channels=in_ch)))
|
205 |
hs_c.append(in_ch)
|
206 |
|
207 |
if i_level != num_resolutions - 1:
|
208 |
if resblock_type == 'ddpm':
|
209 |
modules.append(Downsample(in_ch=in_ch))
|
210 |
else:
|
211 |
+
modules.append(wrap(ResnetBlock(down=True, in_ch=in_ch)))
|
212 |
|
213 |
if progressive_input == 'input_skip':
|
214 |
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
|
|
222 |
hs_c.append(in_ch)
|
223 |
|
224 |
in_ch = hs_c[-1]
|
225 |
+
modules.append(wrap(ResnetBlock(in_ch=in_ch)))
|
226 |
+
modules.append(wrap(AttnBlock(channels=in_ch)))
|
227 |
+
modules.append(wrap(ResnetBlock(in_ch=in_ch)))
|
228 |
|
229 |
pyramid_ch = 0
|
230 |
# Upsampling block
|
231 |
for i_level in reversed(range(num_resolutions)):
|
232 |
for i_block in range(num_res_blocks + 1):
|
233 |
out_ch = nf * ch_mult[i_level]
|
234 |
+
modules.append(wrap(ResnetBlock(in_ch=in_ch + hs_c.pop(),
|
235 |
+
out_ch=out_ch)))
|
236 |
in_ch = out_ch
|
237 |
|
238 |
if all_resolutions[i_level] in attn_resolutions:
|
239 |
+
modules.append(wrap(AttnBlock(channels=in_ch)))
|
240 |
|
241 |
if progressive != 'none':
|
242 |
if i_level == num_resolutions - 1:
|
|
|
268 |
if resblock_type == 'ddpm':
|
269 |
modules.append(Upsample(in_ch=in_ch))
|
270 |
else:
|
271 |
+
modules.append(wrap(ResnetBlock(in_ch=in_ch, up=True)))
|
272 |
|
273 |
assert not hs_c
|
274 |
|
test_ddgan.py
CHANGED
@@ -380,7 +380,11 @@ def sample_and_test(args):
|
|
380 |
epochs = range(1000)
|
381 |
else:
|
382 |
epochs = [args.epoch_id]
|
383 |
-
|
|
|
|
|
|
|
|
|
384 |
for epoch in epochs:
|
385 |
args.epoch_id = epoch
|
386 |
path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
|
@@ -389,7 +393,7 @@ def sample_and_test(args):
|
|
389 |
continue
|
390 |
if not os.path.exists(next_next_path):
|
391 |
break
|
392 |
-
print(path)
|
393 |
|
394 |
#if not os.path.exists(next_path):
|
395 |
# print(f"STOP at {epoch}")
|
@@ -400,9 +404,7 @@ def sample_and_test(args):
|
|
400 |
continue
|
401 |
suffix = '_' + args.eval_name if args.eval_name else ""
|
402 |
dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
|
403 |
-
|
404 |
-
|
405 |
-
if (args.compute_fid or args.compute_clip_score) and os.path.exists(dest):
|
406 |
continue
|
407 |
print("Eval Epoch", args.epoch_id)
|
408 |
#loading weights from ddp in single gpu
|
@@ -424,7 +426,8 @@ def sample_and_test(args):
|
|
424 |
if not os.path.exists(save_dir):
|
425 |
os.makedirs(save_dir)
|
426 |
|
427 |
-
|
|
|
428 |
from torch.nn.functional import adaptive_avg_pool2d
|
429 |
from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
|
430 |
from pytorch_fid.inception import InceptionV3
|
@@ -472,6 +475,8 @@ def sample_and_test(args):
|
|
472 |
|
473 |
if args.compute_clip_score:
|
474 |
clip_scores = []
|
|
|
|
|
475 |
|
476 |
for b in range(0, len(texts), args.batch_size):
|
477 |
text = texts[b:b+args.batch_size]
|
@@ -485,12 +490,7 @@ def sample_and_test(args):
|
|
485 |
else:
|
486 |
fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
|
487 |
fake_sample = to_range_0_1(fake_sample)
|
488 |
-
|
489 |
-
for j, x in enumerate(fake_sample):
|
490 |
-
index = i * args.batch_size + j
|
491 |
-
torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
|
492 |
-
"""
|
493 |
-
|
494 |
if args.compute_fid:
|
495 |
with torch.no_grad():
|
496 |
pred = inceptionv3(fake_sample)[0]
|
@@ -511,9 +511,18 @@ def sample_and_test(args):
|
|
511 |
imf = torch.nn.functional.normalize(imf, dim=1)
|
512 |
txtf = torch.nn.functional.normalize(txtf, dim=1)
|
513 |
clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
|
514 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
515 |
if i % 10 == 0:
|
516 |
print('evaluating batch ', i, time.time() - t0)
|
|
|
517 |
i += 1
|
518 |
|
519 |
results = {}
|
@@ -526,6 +535,9 @@ def sample_and_test(args):
|
|
526 |
if args.compute_clip_score:
|
527 |
clip_score = torch.cat(clip_scores).mean().item()
|
528 |
results['clip_score'] = clip_score
|
|
|
|
|
|
|
529 |
results.update(vars(args))
|
530 |
with open(dest, "w") as fd:
|
531 |
json.dump(results, fd)
|
@@ -591,6 +603,9 @@ if __name__ == '__main__':
|
|
591 |
help='whether or not compute FID')
|
592 |
parser.add_argument('--compute_clip_score', action='store_true', default=False,
|
593 |
help='whether or not compute CLIP score')
|
|
|
|
|
|
|
594 |
parser.add_argument('--clip_model', type=str,default="ViT-L/14")
|
595 |
parser.add_argument('--eval_name', type=str,default="")
|
596 |
|
@@ -625,7 +640,7 @@ if __name__ == '__main__':
|
|
625 |
|
626 |
parser.add_argument('--num_res_blocks', type=int, default=2,
|
627 |
help='number of resnet blocks per scale')
|
628 |
-
parser.add_argument('--attn_resolutions', default=(16,),
|
629 |
help='resolution of applying attention')
|
630 |
parser.add_argument('--dropout', type=float, default=0.,
|
631 |
help='drop-out rate')
|
|
|
380 |
epochs = range(1000)
|
381 |
else:
|
382 |
epochs = [args.epoch_id]
|
383 |
+
if args.compute_image_reward:
|
384 |
+
import ImageReward as RM
|
385 |
+
#image_reward = RM.load("ImageReward-v1.0", download_root=".").to(device)
|
386 |
+
image_reward = RM.load("ImageReward.pt", download_root=".").to(device)
|
387 |
+
|
388 |
for epoch in epochs:
|
389 |
args.epoch_id = epoch
|
390 |
path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
|
|
|
393 |
continue
|
394 |
if not os.path.exists(next_next_path):
|
395 |
break
|
396 |
+
print("PATH", path)
|
397 |
|
398 |
#if not os.path.exists(next_path):
|
399 |
# print(f"STOP at {epoch}")
|
|
|
404 |
continue
|
405 |
suffix = '_' + args.eval_name if args.eval_name else ""
|
406 |
dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
|
407 |
+
if (args.compute_fid or args.compute_clip_score or args.compute_image_reward) and os.path.exists(dest):
|
|
|
|
|
408 |
continue
|
409 |
print("Eval Epoch", args.epoch_id)
|
410 |
#loading weights from ddp in single gpu
|
|
|
426 |
if not os.path.exists(save_dir):
|
427 |
os.makedirs(save_dir)
|
428 |
|
429 |
+
|
430 |
+
if args.compute_fid or args.compute_clip_score or args.compute_image_reward:
|
431 |
from torch.nn.functional import adaptive_avg_pool2d
|
432 |
from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
|
433 |
from pytorch_fid.inception import InceptionV3
|
|
|
475 |
|
476 |
if args.compute_clip_score:
|
477 |
clip_scores = []
|
478 |
+
if args.compute_image_reward:
|
479 |
+
image_rewards = []
|
480 |
|
481 |
for b in range(0, len(texts), args.batch_size):
|
482 |
text = texts[b:b+args.batch_size]
|
|
|
490 |
else:
|
491 |
fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
|
492 |
fake_sample = to_range_0_1(fake_sample)
|
493 |
+
|
|
|
|
|
|
|
|
|
|
|
494 |
if args.compute_fid:
|
495 |
with torch.no_grad():
|
496 |
pred = inceptionv3(fake_sample)[0]
|
|
|
511 |
imf = torch.nn.functional.normalize(imf, dim=1)
|
512 |
txtf = torch.nn.functional.normalize(txtf, dim=1)
|
513 |
clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
|
514 |
+
|
515 |
+
if args.compute_image_reward:
|
516 |
+
for k, sample in enumerate(fake_sample):
|
517 |
+
img = sample.cpu().numpy().transpose(1,2,0)
|
518 |
+
img = img * 255
|
519 |
+
img = img.astype(np.uint8)
|
520 |
+
text_k = text[k]
|
521 |
+
score = image_reward.score(text_k, img)
|
522 |
+
image_rewards.append(score)
|
523 |
if i % 10 == 0:
|
524 |
print('evaluating batch ', i, time.time() - t0)
|
525 |
+
#break
|
526 |
i += 1
|
527 |
|
528 |
results = {}
|
|
|
535 |
if args.compute_clip_score:
|
536 |
clip_score = torch.cat(clip_scores).mean().item()
|
537 |
results['clip_score'] = clip_score
|
538 |
+
if args.compute_image_reward:
|
539 |
+
reward = np.mean(image_rewards)
|
540 |
+
results['image_reward'] = reward
|
541 |
results.update(vars(args))
|
542 |
with open(dest, "w") as fd:
|
543 |
json.dump(results, fd)
|
|
|
603 |
help='whether or not compute FID')
|
604 |
parser.add_argument('--compute_clip_score', action='store_true', default=False,
|
605 |
help='whether or not compute CLIP score')
|
606 |
+
parser.add_argument('--compute_image_reward', action='store_true', default=False,
|
607 |
+
help='whether or not compute CLIP score')
|
608 |
+
|
609 |
parser.add_argument('--clip_model', type=str,default="ViT-L/14")
|
610 |
parser.add_argument('--eval_name', type=str,default="")
|
611 |
|
|
|
640 |
|
641 |
parser.add_argument('--num_res_blocks', type=int, default=2,
|
642 |
help='number of resnet blocks per scale')
|
643 |
+
parser.add_argument('--attn_resolutions', default=(16,), nargs='+', type=int,
|
644 |
help='resolution of applying attention')
|
645 |
parser.add_argument('--dropout', type=float, default=0.,
|
646 |
help='drop-out rate')
|
train_ddgan.py
CHANGED
@@ -4,14 +4,14 @@
|
|
4 |
# This work is licensed under the NVIDIA Source Code License
|
5 |
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
|
6 |
# ---------------------------------------------------------------
|
|
|
7 |
|
8 |
from glob import glob
|
9 |
import argparse
|
10 |
-
import torch
|
11 |
import numpy as np
|
12 |
-
|
13 |
import os
|
14 |
-
|
15 |
import torch.nn as nn
|
16 |
import torch.nn.functional as F
|
17 |
import torch.optim as optim
|
@@ -288,6 +288,15 @@ def train(rank, gpu, args):
|
|
288 |
transforms.ToTensor(),
|
289 |
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
|
290 |
])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
shards = glob(os.path.join(args.dataset_root, "*.tar")) if os.path.isdir(args.dataset_root) else args.dataset_root
|
292 |
pipeline = [ResampledShards2(shards)]
|
293 |
pipeline.extend([
|
@@ -312,7 +321,7 @@ def train(rank, gpu, args):
|
|
312 |
dataset,
|
313 |
batch_size=None,
|
314 |
shuffle=False,
|
315 |
-
num_workers=
|
316 |
)
|
317 |
|
318 |
if args.dataset != "wds":
|
@@ -355,6 +364,7 @@ def train(rank, gpu, args):
|
|
355 |
cond_size=text_encoder.output_size,
|
356 |
act=nn.LeakyReLU(0.2)).to(device)
|
357 |
elif args.discr_type == "large_attn_pool":
|
|
|
358 |
netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
|
359 |
t_emb_dim = args.t_emb_dim,
|
360 |
cond_size=text_encoder.output_size,
|
@@ -362,6 +372,7 @@ def train(rank, gpu, args):
|
|
362 |
act=nn.LeakyReLU(0.2)).to(device)
|
363 |
|
364 |
elif args.discr_type == "large_cond_attn":
|
|
|
365 |
netD = CondAttnDiscriminator(
|
366 |
nc = 2*args.num_channels,
|
367 |
ngf = args.ngf,
|
@@ -391,7 +402,7 @@ def train(rank, gpu, args):
|
|
391 |
optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
|
392 |
|
393 |
if args.use_ema:
|
394 |
-
optimizerG = EMA(optimizerG, ema_decay=args.ema_decay)
|
395 |
|
396 |
schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
|
397 |
schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
|
@@ -403,12 +414,10 @@ def train(rank, gpu, args):
|
|
403 |
netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
|
404 |
#if args.discr_type == "projected_gan":
|
405 |
# netD._set_static_graph()
|
406 |
-
|
407 |
|
408 |
-
if args.grad_checkpointing:
|
409 |
-
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
410 |
-
netG = checkpoint_wrapper(netG)
|
411 |
-
|
412 |
exp = args.exp
|
413 |
parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
|
414 |
|
@@ -442,8 +451,9 @@ def train(rank, gpu, args):
|
|
442 |
optimizerD.load_state_dict(checkpoint['optimizerD'])
|
443 |
schedulerD.load_state_dict(checkpoint['schedulerD'])
|
444 |
global_step = checkpoint['global_step']
|
445 |
-
|
446 |
-
|
|
|
447 |
else:
|
448 |
global_step, epoch, init_epoch = 0, 0, 0
|
449 |
use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn", "large_attn_pool", "projected_gan")
|
@@ -454,6 +464,7 @@ def train(rank, gpu, args):
|
|
454 |
train_sampler.set_epoch(epoch)
|
455 |
|
456 |
for iteration, (x, y) in enumerate(data_loader):
|
|
|
457 |
#print(x.shape)
|
458 |
if args.dataset != "wds":
|
459 |
y = [str(yi) for yi in y.tolist()]
|
@@ -631,6 +642,8 @@ def train(rank, gpu, args):
|
|
631 |
if rank == 0:
|
632 |
print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
|
633 |
print('Global step:', global_step)
|
|
|
|
|
634 |
if iteration % 1000 == 0:
|
635 |
x_t_1 = torch.randn_like(real_data)
|
636 |
with autocast():
|
@@ -640,7 +653,8 @@ def train(rank, gpu, args):
|
|
640 |
|
641 |
if args.save_content:
|
642 |
dist.barrier()
|
643 |
-
|
|
|
644 |
def to_cpu(d):
|
645 |
for k, v in d.items():
|
646 |
d[k] = v.cpu()
|
@@ -677,6 +691,9 @@ def train(rank, gpu, args):
|
|
677 |
'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
|
678 |
torch.save(content, os.path.join(exp_path, 'content.pth'))
|
679 |
torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
|
|
|
|
|
|
|
680 |
if args.use_ema:
|
681 |
optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
|
682 |
torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
|
@@ -685,40 +702,8 @@ def train(rank, gpu, args):
|
|
685 |
|
686 |
|
687 |
if not args.no_lr_decay:
|
688 |
-
|
689 |
schedulerG.step()
|
690 |
schedulerD.step()
|
691 |
-
"""
|
692 |
-
if rank == 0:
|
693 |
-
if epoch % 10 == 0:
|
694 |
-
torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)
|
695 |
-
|
696 |
-
x_t_1 = torch.randn_like(real_data)
|
697 |
-
with autocast():
|
698 |
-
fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
|
699 |
-
torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)
|
700 |
-
|
701 |
-
if args.save_content:
|
702 |
-
if epoch % args.save_content_every == 0:
|
703 |
-
print('Saving content.')
|
704 |
-
content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
|
705 |
-
'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
|
706 |
-
'schedulerG': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
|
707 |
-
'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
|
708 |
-
|
709 |
-
torch.save(content, os.path.join(exp_path, 'content.pth'))
|
710 |
-
torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
|
711 |
-
|
712 |
-
if epoch % args.save_ckpt_every == 0:
|
713 |
-
if args.use_ema:
|
714 |
-
optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
|
715 |
-
|
716 |
-
torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
|
717 |
-
if args.use_ema:
|
718 |
-
optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
|
719 |
-
dist.barrier()
|
720 |
-
"""
|
721 |
-
|
722 |
|
723 |
def init_processes(rank, size, fn, args):
|
724 |
""" Initialize the distributed environment. """
|
@@ -748,12 +733,12 @@ if __name__ == '__main__':
|
|
748 |
help='seed used for initialization')
|
749 |
|
750 |
parser.add_argument('--resume', action='store_true',default=False)
|
751 |
-
parser.add_argument('--masked_mean', action='store_true',default=False)
|
752 |
-
parser.add_argument('--mismatch_loss', action='store_true',default=False)
|
753 |
parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
|
754 |
-
parser.add_argument('--cross_attention', action='store_true',default=False)
|
755 |
-
parser.add_argument('--fsdp', action='store_true',default=False)
|
756 |
-
parser.add_argument('--grad_checkpointing', action='store_true',default=False)
|
757 |
|
758 |
parser.add_argument('--image_size', type=int, default=32,
|
759 |
help='size of image')
|
@@ -767,9 +752,8 @@ if __name__ == '__main__':
|
|
767 |
parser.add_argument('--beta_max', type=float, default=20.,
|
768 |
help='beta_max for diffusion')
|
769 |
parser.add_argument('--classifier_free_guidance_proba', type=float, default=0.0)
|
770 |
-
|
771 |
parser.add_argument('--num_channels_dae', type=int, default=128,
|
772 |
-
help='number of initial channels in denosing model')
|
773 |
parser.add_argument('--n_mlp', type=int, default=3,
|
774 |
help='number of mlp layers for z')
|
775 |
parser.add_argument('--ch_mult', nargs='+', type=int,
|
@@ -825,7 +809,7 @@ if __name__ == '__main__':
|
|
825 |
parser.add_argument('--beta2', type=float, default=0.9,
|
826 |
help='beta2 for adam')
|
827 |
parser.add_argument('--no_lr_decay',action='store_true', default=False)
|
828 |
-
parser.add_argument('--grad_penalty_cond', action='store_true',default=False)
|
829 |
|
830 |
parser.add_argument('--use_ema', action='store_true', default=False,
|
831 |
help='use EMA or not')
|
|
|
4 |
# This work is licensed under the NVIDIA Source Code License
|
5 |
# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
|
6 |
# ---------------------------------------------------------------
|
7 |
+
import torch
|
8 |
|
9 |
from glob import glob
|
10 |
import argparse
|
|
|
11 |
import numpy as np
|
12 |
+
import json
|
13 |
import os
|
14 |
+
import time
|
15 |
import torch.nn as nn
|
16 |
import torch.nn.functional as F
|
17 |
import torch.optim as optim
|
|
|
288 |
transforms.ToTensor(),
|
289 |
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
|
290 |
])
|
291 |
+
elif args.preprocessing == "simple_random_crop_v2":
|
292 |
+
train_transform = transforms.Compose([
|
293 |
+
transforms.Resize(args.image_size),
|
294 |
+
transforms.RandomCrop(args.image_size, interpolation=3),
|
295 |
+
transforms.ToTensor(),
|
296 |
+
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
|
297 |
+
])
|
298 |
+
else:
|
299 |
+
raise ValueError(args.preprocessing)
|
300 |
shards = glob(os.path.join(args.dataset_root, "*.tar")) if os.path.isdir(args.dataset_root) else args.dataset_root
|
301 |
pipeline = [ResampledShards2(shards)]
|
302 |
pipeline.extend([
|
|
|
321 |
dataset,
|
322 |
batch_size=None,
|
323 |
shuffle=False,
|
324 |
+
num_workers=1,
|
325 |
)
|
326 |
|
327 |
if args.dataset != "wds":
|
|
|
364 |
cond_size=text_encoder.output_size,
|
365 |
act=nn.LeakyReLU(0.2)).to(device)
|
366 |
elif args.discr_type == "large_attn_pool":
|
367 |
+
# Discriminator with Attention Pool based discriminator for text conditioning
|
368 |
netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
|
369 |
t_emb_dim = args.t_emb_dim,
|
370 |
cond_size=text_encoder.output_size,
|
|
|
372 |
act=nn.LeakyReLU(0.2)).to(device)
|
373 |
|
374 |
elif args.discr_type == "large_cond_attn":
|
375 |
+
# Discriminator with Cross-Attention based discriminator for text conditioning
|
376 |
netD = CondAttnDiscriminator(
|
377 |
nc = 2*args.num_channels,
|
378 |
ngf = args.ngf,
|
|
|
402 |
optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
|
403 |
|
404 |
if args.use_ema:
|
405 |
+
optimizerG = EMA(optimizerG, ema_decay=args.ema_decay, memory_efficient=args.grad_checkpointing)
|
406 |
|
407 |
schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
|
408 |
schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
|
|
|
414 |
netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
|
415 |
#if args.discr_type == "projected_gan":
|
416 |
# netD._set_static_graph()
|
|
|
417 |
|
418 |
+
#if args.grad_checkpointing:
|
419 |
+
#from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
|
420 |
+
#netG = checkpoint_wrapper(netG)
|
|
|
421 |
exp = args.exp
|
422 |
parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
|
423 |
|
|
|
451 |
optimizerD.load_state_dict(checkpoint['optimizerD'])
|
452 |
schedulerD.load_state_dict(checkpoint['schedulerD'])
|
453 |
global_step = checkpoint['global_step']
|
454 |
+
if rank == 0:
|
455 |
+
print("=> loaded checkpoint (epoch {})"
|
456 |
+
.format(checkpoint['epoch']))
|
457 |
else:
|
458 |
global_step, epoch, init_epoch = 0, 0, 0
|
459 |
use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn", "large_attn_pool", "projected_gan")
|
|
|
464 |
train_sampler.set_epoch(epoch)
|
465 |
|
466 |
for iteration, (x, y) in enumerate(data_loader):
|
467 |
+
t0 = time.time()
|
468 |
#print(x.shape)
|
469 |
if args.dataset != "wds":
|
470 |
y = [str(yi) for yi in y.tolist()]
|
|
|
642 |
if rank == 0:
|
643 |
print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
|
644 |
print('Global step:', global_step)
|
645 |
+
dt = time.time() - t0
|
646 |
+
print('Time per iteration: ', dt)
|
647 |
if iteration % 1000 == 0:
|
648 |
x_t_1 = torch.randn_like(real_data)
|
649 |
with autocast():
|
|
|
653 |
|
654 |
if args.save_content:
|
655 |
dist.barrier()
|
656 |
+
if rank == 0:
|
657 |
+
print('Saving content.')
|
658 |
def to_cpu(d):
|
659 |
for k, v in d.items():
|
660 |
d[k] = v.cpu()
|
|
|
691 |
'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
|
692 |
torch.save(content, os.path.join(exp_path, 'content.pth'))
|
693 |
torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
|
694 |
+
state_content = {'epoch': epoch + 1, 'global_step': global_step}
|
695 |
+
with open(os.path.join(exp_path, 'netG_{}.json'.format(epoch)), "w") as fd:
|
696 |
+
fd.write(json.dumps(state_content))
|
697 |
if args.use_ema:
|
698 |
optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
|
699 |
torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
|
|
|
702 |
|
703 |
|
704 |
if not args.no_lr_decay:
|
|
|
705 |
schedulerG.step()
|
706 |
schedulerD.step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
707 |
|
708 |
def init_processes(rank, size, fn, args):
|
709 |
""" Initialize the distributed environment. """
|
|
|
733 |
help='seed used for initialization')
|
734 |
|
735 |
parser.add_argument('--resume', action='store_true',default=False)
|
736 |
+
parser.add_argument('--masked_mean', action='store_true',default=False, help="use masked mean to pool from t5-based text encoder")
|
737 |
+
parser.add_argument('--mismatch_loss', action='store_true',default=False, help="use mismatch loss")
|
738 |
parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
|
739 |
+
parser.add_argument('--cross_attention', action='store_true',default=False, help="use cross attention in generator")
|
740 |
+
parser.add_argument('--fsdp', action='store_true',default=False, help='use FSDP')
|
741 |
+
parser.add_argument('--grad_checkpointing', action='store_true',default=False, help='use grad checkpointing')
|
742 |
|
743 |
parser.add_argument('--image_size', type=int, default=32,
|
744 |
help='size of image')
|
|
|
752 |
parser.add_argument('--beta_max', type=float, default=20.,
|
753 |
help='beta_max for diffusion')
|
754 |
parser.add_argument('--classifier_free_guidance_proba', type=float, default=0.0)
|
|
|
755 |
parser.add_argument('--num_channels_dae', type=int, default=128,
|
756 |
+
help='number of initial channels in denosing model generator')
|
757 |
parser.add_argument('--n_mlp', type=int, default=3,
|
758 |
help='number of mlp layers for z')
|
759 |
parser.add_argument('--ch_mult', nargs='+', type=int,
|
|
|
809 |
parser.add_argument('--beta2', type=float, default=0.9,
|
810 |
help='beta2 for adam')
|
811 |
parser.add_argument('--no_lr_decay',action='store_true', default=False)
|
812 |
+
parser.add_argument('--grad_penalty_cond', action='store_true',default=False, help="cond based grad penalty")
|
813 |
|
814 |
parser.add_argument('--use_ema', action='store_true', default=False,
|
815 |
help='use EMA or not')
|