Mehdi Cherti committed on
Commit 3dcdf92
1 Parent(s): e96a195

- memory efficient EMA
- fix gradient checkpointing
- evaluate using image reward paper
- use attn_resolution in test

Files changed (5)
  1. EMA.py +13 -4
  2. run.py +29 -1
  3. score_sde/models/ncsnpp_generator_adagn.py +18 -10
  4. test_ddgan.py +29 -14
  5. train_ddgan.py +37 -53
EMA.py CHANGED
@@ -15,13 +15,14 @@ from torch.optim import Optimizer
 
 
 class EMA(Optimizer):
-    def __init__(self, opt, ema_decay):
+    def __init__(self, opt, ema_decay, memory_efficient=False):
         self.ema_decay = ema_decay
         self.apply_ema = self.ema_decay > 0.
         self.optimizer = opt
         self.state = opt.state
         self.param_groups = opt.param_groups
         self.defaults = {}
+        self.memory_efficient = memory_efficient
 
     def step(self, *args, **kwargs):
         # for group in self.optimizer.param_groups:
@@ -53,11 +54,19 @@ class EMA(Optimizer):
 
                 params[p.shape]['data'].append(p.data)
                 ema[p.shape].append(state['ema'])
+
+            # def stack(d, dim=0):
+            #     return torch.stack([di.cpu() for di in d], dim=dim).cuda()
 
             for i in params:
-                params[i]['data'] = torch.stack(params[i]['data'], dim=0)
-                ema[i] = torch.stack(ema[i], dim=0)
-                ema[i].mul_(self.ema_decay).add_(params[i]['data'], alpha=1. - self.ema_decay)
+                if self.memory_efficient:
+                    for j in range(len(params[i]['data'])):
+                        ema[i][j].mul_(self.ema_decay).add_(params[i]['data'][j], alpha=1. - self.ema_decay)
+                    ema[i] = torch.stack(ema[i], dim=0)
+                else:
+                    params[i]['data'] = torch.stack(params[i]['data'], dim=0)
+                    ema[i] = torch.stack(ema[i], dim=0)
+                    ema[i].mul_(self.ema_decay).add_(params[i]['data'], alpha=1. - self.ema_decay)
 
             for p in group['params']:
                 if p.grad is None:
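
The memory_efficient branch exists because torch.stack materializes a fresh (n, *shape) tensor for both the parameters and their EMA buffers before the fused update, while the per-tensor loop performs the same update in place, one buffer at a time. A minimal standalone sketch of the two paths (plain tensor lists, not the optimizer-state bookkeeping the class above uses):

    import torch

    def ema_update_stacked(ema, params, decay):
        # Fused path: allocates two temporaries of shape (n, *param_shape).
        p = torch.stack(params, dim=0)
        e = torch.stack(ema, dim=0)
        e.mul_(decay).add_(p, alpha=1. - decay)
        return list(e.unbind(0))

    def ema_update_inplace(ema, params, decay):
        # Memory-efficient path: same result, no large temporaries.
        for e, p in zip(ema, params):
            e.mul_(decay).add_(p, alpha=1. - decay)
        return ema

    params = [torch.randn(64, 64) for _ in range(4)]
    ema = [torch.zeros_like(p) for p in params]
    ema = ema_update_inplace(ema, params, decay=0.999)
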
run.py CHANGED
@@ -274,10 +274,30 @@ def ddgan_ddb_v7():
     cfg = ddgan_ddb_v1()
     return cfg
 
+def ddgan_ddb_v9():
+    cfg = ddgan_ddb_v3()
+    cfg['model']['attn_resolutions'] = '4 8 16 32'
+    return cfg
+
 def ddgan_laion_aesthetic_v15():
     cfg = ddgan_ddb_v3()
     return cfg
 
+def ddgan_ddb_v10():
+    cfg = ddgan_ddb_v9()
+    return cfg
+
+def ddgan_ddb_v11():
+    cfg = ddgan_ddb_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-g-14/laion2B-s12B-b42K"
+    return cfg
+
+def ddgan_ddb_v12():
+    cfg = ddgan_ddb_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-bigG-14/laion2b_s39b_b160k"
+    return cfg
+
+
 models = [
     ddgan_cifar10_cond17, # cifar10, cross attn for discr
     ddgan_cifar10_cond18, # cifar10, xl encoder
@@ -326,6 +346,10 @@ models = [
     ddgan_ddb_v5,
     ddgan_ddb_v6,
     ddgan_ddb_v7,
+    ddgan_ddb_v9,
+    ddgan_ddb_v10,
+    ddgan_ddb_v11,
+    ddgan_ddb_v12,
 ]
 
 def get_model(model_name):
@@ -334,7 +358,7 @@ def get_model(model_name):
     return model()
 
 
-def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional"):
+def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional", compute_image_reward=False):
 
     cfg = get_model(model_name)
     model = cfg['model']
@@ -365,12 +389,16 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional"):
     args['scale_factor_w'] = scale_factor_w
     args['n_mlp'] = model.get("n_mlp")
     args['scale_method'] = scale_method
+    args['attn_resolutions'] = model.get("attn_resolutions", "16")
     if fid:
         args['compute_fid'] = ''
         args['real_img_dir'] = real_img_dir
         args['nb_images_for_fid'] = nb_images_for_fid
     if compute_clip_score:
         args['compute_clip_score'] = ""
+
+    if compute_image_reward:
+        args['compute_image_reward'] = ""
     if eval_name:
         args["eval_name"] = eval_name
     cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
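
The test() helper serializes its args dict into a test_ddgan.py invocation: an empty-string value emits a bare store_true flag, and None drops the option entirely. A toy illustration of that idiom (the values here are made up):

    # Same join as in test() above, on a hypothetical args dict.
    args = {"epoch_id": 5, "compute_fid": "", "n_mlp": None, "attn_resolutions": "16"}
    cmd = "python -u test_ddgan.py " + " ".join(
        f"--{k} {v}" for k, v in args.items() if v is not None
    )
    # -> python -u test_ddgan.py --epoch_id 5 --compute_fid  --attn_resolutions 16
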
score_sde/models/ncsnpp_generator_adagn.py CHANGED
@@ -37,6 +37,11 @@ import functools
 import torch
 import numpy as np
 
+try:
+    from fairscale.nn.checkpoint import checkpoint_wrapper
+except Exception:
+    checkpoint_wrapper = lambda x: x
+
 
 ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp_Adagn
 ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp_Adagn
@@ -63,6 +68,7 @@ class NCSNpp(nn.Module):
   def __init__(self, config):
     super().__init__()
     self.config = config
+    self.grad_checkpointing = config.grad_checkpointing if hasattr(config, "grad_checkpointing") else False
     self.not_use_tanh = config.not_use_tanh
     self.act = act = nn.SiLU()
     self.z_emb_dim = z_emb_dim = config.z_emb_dim
@@ -176,6 +182,8 @@ class NCSNpp(nn.Module):
       raise ValueError(f'resblock type {resblock_type} unrecognized.')
 
     # Downsampling block
+    def wrap(block):
+      return checkpoint_wrapper(block) if self.grad_checkpointing else block
 
     channels = config.num_channels
     if progressive_input != 'none':
@@ -189,18 +197,18 @@ class NCSNpp(nn.Module):
       # Residual blocks for this resolution
       for i_block in range(num_res_blocks):
         out_ch = nf * ch_mult[i_level]
-        modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+        modules.append(wrap(ResnetBlock(in_ch=in_ch, out_ch=out_ch)))
         in_ch = out_ch
 
         if all_resolutions[i_level] in attn_resolutions:
-          modules.append(AttnBlock(channels=in_ch))
+          modules.append(wrap(AttnBlock(channels=in_ch)))
         hs_c.append(in_ch)
 
       if i_level != num_resolutions - 1:
         if resblock_type == 'ddpm':
           modules.append(Downsample(in_ch=in_ch))
         else:
-          modules.append(ResnetBlock(down=True, in_ch=in_ch))
+          modules.append(wrap(ResnetBlock(down=True, in_ch=in_ch)))
 
         if progressive_input == 'input_skip':
           modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -214,21 +222,21 @@ class NCSNpp(nn.Module):
       hs_c.append(in_ch)
 
     in_ch = hs_c[-1]
-    modules.append(ResnetBlock(in_ch=in_ch))
-    modules.append(AttnBlock(channels=in_ch))
-    modules.append(ResnetBlock(in_ch=in_ch))
+    modules.append(wrap(ResnetBlock(in_ch=in_ch)))
+    modules.append(wrap(AttnBlock(channels=in_ch)))
+    modules.append(wrap(ResnetBlock(in_ch=in_ch)))
 
     pyramid_ch = 0
     # Upsampling block
     for i_level in reversed(range(num_resolutions)):
       for i_block in range(num_res_blocks + 1):
         out_ch = nf * ch_mult[i_level]
-        modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(),
-                                   out_ch=out_ch))
+        modules.append(wrap(ResnetBlock(in_ch=in_ch + hs_c.pop(),
+                                        out_ch=out_ch)))
         in_ch = out_ch
 
         if all_resolutions[i_level] in attn_resolutions:
-          modules.append(AttnBlock(channels=in_ch))
+          modules.append(wrap(AttnBlock(channels=in_ch)))
 
       if progressive != 'none':
         if i_level == num_resolutions - 1:
@@ -260,7 +268,7 @@ class NCSNpp(nn.Module):
       if resblock_type == 'ddpm':
         modules.append(Upsample(in_ch=in_ch))
       else:
-        modules.append(ResnetBlock(in_ch=in_ch, up=True))
+        modules.append(wrap(ResnetBlock(in_ch=in_ch, up=True)))
 
     assert not hs_c
 
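The gradient checkpointing fix moves the wrapping from the whole generator (see the commented-out block in train_ddgan.py below) down to individual blocks: each ResnetBlock/AttnBlock is wrapped at construction time, and the try/except import keeps fairscale optional. A condensed sketch of the pattern, using a toy block rather than the NCSNpp modules:

    import torch.nn as nn

    try:
        from fairscale.nn.checkpoint import checkpoint_wrapper
    except Exception:
        checkpoint_wrapper = lambda module: module  # identity fallback when fairscale is absent

    class ToyNet(nn.Module):
        def __init__(self, grad_checkpointing=False):
            super().__init__()
            def wrap(block):
                # Recompute this block's activations in backward only when enabled.
                return checkpoint_wrapper(block) if grad_checkpointing else block
            self.blocks = nn.ModuleList(
                [wrap(nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.SiLU())) for _ in range(4)]
            )

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x
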
test_ddgan.py CHANGED
@@ -380,7 +380,11 @@ def sample_and_test(args):
         epochs = range(1000)
     else:
         epochs = [args.epoch_id]
-
+    if args.compute_image_reward:
+        import ImageReward as RM
+        #image_reward = RM.load("ImageReward-v1.0", download_root=".").to(device)
+        image_reward = RM.load("ImageReward.pt", download_root=".").to(device)
+
     for epoch in epochs:
         args.epoch_id = epoch
         path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
@@ -389,7 +393,7 @@ def sample_and_test(args):
             continue
         if not os.path.exists(next_next_path):
             break
-        print(path)
+        print("PATH", path)
 
         #if not os.path.exists(next_path):
         #    print(f"STOP at {epoch}")
@@ -400,9 +404,7 @@ def sample_and_test(args):
             continue
         suffix = '_' + args.eval_name if args.eval_name else ""
         dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
-        next_dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id+1, suffix)
-
-        if (args.compute_fid or args.compute_clip_score) and os.path.exists(dest):
+        if (args.compute_fid or args.compute_clip_score or args.compute_image_reward) and os.path.exists(dest):
             continue
         print("Eval Epoch", args.epoch_id)
         #loading weights from ddp in single gpu
@@ -424,7 +426,8 @@ def sample_and_test(args):
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
 
-        if args.compute_fid or args.compute_clip_score:
+
+        if args.compute_fid or args.compute_clip_score or args.compute_image_reward:
             from torch.nn.functional import adaptive_avg_pool2d
             from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
             from pytorch_fid.inception import InceptionV3
@@ -472,6 +475,8 @@ def sample_and_test(args):
 
             if args.compute_clip_score:
                 clip_scores = []
+            if args.compute_image_reward:
+                image_rewards = []
 
             for b in range(0, len(texts), args.batch_size):
                 text = texts[b:b+args.batch_size]
@@ -485,12 +490,7 @@ def sample_and_test(args):
                 else:
                     fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
                 fake_sample = to_range_0_1(fake_sample)
-                """
-                for j, x in enumerate(fake_sample):
-                    index = i * args.batch_size + j
-                    torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
-                """
-
+
                 if args.compute_fid:
                     with torch.no_grad():
                         pred = inceptionv3(fake_sample)[0]
@@ -511,9 +511,18 @@ def sample_and_test(args):
                         imf = torch.nn.functional.normalize(imf, dim=1)
                         txtf = torch.nn.functional.normalize(txtf, dim=1)
                         clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
-
+
+                if args.compute_image_reward:
+                    for k, sample in enumerate(fake_sample):
+                        img = sample.cpu().numpy().transpose(1,2,0)
+                        img = img * 255
+                        img = img.astype(np.uint8)
+                        text_k = text[k]
+                        score = image_reward.score(text_k, img)
+                        image_rewards.append(score)
                 if i % 10 == 0:
                     print('evaluating batch ', i, time.time() - t0)
+                #break
                 i += 1
 
             results = {}
@@ -526,6 +535,9 @@ def sample_and_test(args):
             if args.compute_clip_score:
                 clip_score = torch.cat(clip_scores).mean().item()
                 results['clip_score'] = clip_score
+            if args.compute_image_reward:
+                reward = np.mean(image_rewards)
+                results['image_reward'] = reward
             results.update(vars(args))
             with open(dest, "w") as fd:
                 json.dump(results, fd)
@@ -591,6 +603,9 @@ if __name__ == '__main__':
                         help='whether or not compute FID')
     parser.add_argument('--compute_clip_score', action='store_true', default=False,
                         help='whether or not compute CLIP score')
+    parser.add_argument('--compute_image_reward', action='store_true', default=False,
+                        help='whether or not compute image reward')
+
     parser.add_argument('--clip_model', type=str,default="ViT-L/14")
     parser.add_argument('--eval_name', type=str,default="")
 
@@ -625,7 +640,7 @@ if __name__ == '__main__':
 
     parser.add_argument('--num_res_blocks', type=int, default=2,
                         help='number of resnet blocks per scale')
-    parser.add_argument('--attn_resolutions', default=(16,),
+    parser.add_argument('--attn_resolutions', default=(16,), nargs='+', type=int,
                         help='resolution of applying attention')
     parser.add_argument('--dropout', type=float, default=0.,
                         help='drop-out rate')
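
The evaluation loop above converts each sampled tensor to an HWC uint8 array before handing it to ImageReward (the model released with the image reward paper). A hedged, self-contained sketch of that per-batch scoring, assuming the ImageReward package's RM.load / model.score API as used in the diff; whether score() accepts a raw uint8 array or needs a PIL image may depend on the package version:

    import numpy as np

    def score_batch(image_reward, fake_sample, texts):
        # fake_sample: tensor of shape (B, C, H, W) with values in [0, 1]
        rewards = []
        for k, sample in enumerate(fake_sample):
            img = sample.cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
            img = (img * 255).astype(np.uint8)
            rewards.append(image_reward.score(texts[k], img))
        return rewards
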
train_ddgan.py CHANGED
@@ -4,14 +4,14 @@
 # This work is licensed under the NVIDIA Source Code License
 # for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
 # ---------------------------------------------------------------
+import torch
 
 from glob import glob
 import argparse
-import torch
 import numpy as np
-
+import json
 import os
-
+import time
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
@@ -288,6 +288,15 @@ def train(rank, gpu, args):
             transforms.ToTensor(),
             transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
         ])
+    elif args.preprocessing == "simple_random_crop_v2":
+        train_transform = transforms.Compose([
+            transforms.Resize(args.image_size),
+            transforms.RandomCrop(args.image_size, interpolation=3),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
+        ])
+    else:
+        raise ValueError(args.preprocessing)
     shards = glob(os.path.join(args.dataset_root, "*.tar")) if os.path.isdir(args.dataset_root) else args.dataset_root
     pipeline = [ResampledShards2(shards)]
     pipeline.extend([
@@ -312,7 +321,7 @@ def train(rank, gpu, args):
         dataset,
         batch_size=None,
         shuffle=False,
-        num_workers=8,
+        num_workers=1,
     )
 
     if args.dataset != "wds":
@@ -355,6 +364,7 @@ def train(rank, gpu, args):
             cond_size=text_encoder.output_size,
             act=nn.LeakyReLU(0.2)).to(device)
     elif args.discr_type == "large_attn_pool":
+        # Discriminator with attention-pool based text conditioning
         netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
             t_emb_dim = args.t_emb_dim,
             cond_size=text_encoder.output_size,
@@ -362,6 +372,7 @@ def train(rank, gpu, args):
             act=nn.LeakyReLU(0.2)).to(device)
 
     elif args.discr_type == "large_cond_attn":
+        # Discriminator with cross-attention based text conditioning
         netD = CondAttnDiscriminator(
             nc = 2*args.num_channels,
             ngf = args.ngf,
@@ -391,7 +402,7 @@ def train(rank, gpu, args):
     optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
 
     if args.use_ema:
-        optimizerG = EMA(optimizerG, ema_decay=args.ema_decay)
+        optimizerG = EMA(optimizerG, ema_decay=args.ema_decay, memory_efficient=args.grad_checkpointing)
 
     schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
     schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
@@ -403,12 +414,10 @@ def train(rank, gpu, args):
     netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
     #if args.discr_type == "projected_gan":
     #    netD._set_static_graph()
-
 
-    if args.grad_checkpointing:
-        from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-        netG = checkpoint_wrapper(netG)
-
+    #if args.grad_checkpointing:
+    #    from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+    #    netG = checkpoint_wrapper(netG)
     exp = args.exp
     parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
 
@@ -442,8 +451,9 @@ def train(rank, gpu, args):
         optimizerD.load_state_dict(checkpoint['optimizerD'])
         schedulerD.load_state_dict(checkpoint['schedulerD'])
         global_step = checkpoint['global_step']
-        print("=> loaded checkpoint (epoch {})"
-              .format(checkpoint['epoch']))
+        if rank == 0:
+            print("=> loaded checkpoint (epoch {})"
+                  .format(checkpoint['epoch']))
     else:
         global_step, epoch, init_epoch = 0, 0, 0
     use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn", "large_attn_pool", "projected_gan")
@@ -454,6 +464,7 @@ def train(rank, gpu, args):
         train_sampler.set_epoch(epoch)
 
         for iteration, (x, y) in enumerate(data_loader):
+            t0 = time.time()
            #print(x.shape)
            if args.dataset != "wds":
                y = [str(yi) for yi in y.tolist()]
@@ -631,6 +642,8 @@ def train(rank, gpu, args):
            if rank == 0:
                print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
                print('Global step:', global_step)
+                dt = time.time() - t0
+                print('Time per iteration: ', dt)
            if iteration % 1000 == 0:
                x_t_1 = torch.randn_like(real_data)
                with autocast():
@@ -640,7 +653,8 @@ def train(rank, gpu, args):
 
        if args.save_content:
            dist.barrier()
-            print('Saving content.')
+            if rank == 0:
+                print('Saving content.')
            def to_cpu(d):
                for k, v in d.items():
                    d[k] = v.cpu()
@@ -677,6 +691,9 @@ def train(rank, gpu, args):
                       'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
            torch.save(content, os.path.join(exp_path, 'content.pth'))
            torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
+            state_content = {'epoch': epoch + 1, 'global_step': global_step}
+            with open(os.path.join(exp_path, 'netG_{}.json'.format(epoch)), "w") as fd:
+                fd.write(json.dumps(state_content))
            if args.use_ema:
                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
            torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
@@ -685,40 +702,8 @@ def train(rank, gpu, args):
 
 
        if not args.no_lr_decay:
-
            schedulerG.step()
            schedulerD.step()
-        """
-        if rank == 0:
-            if epoch % 10 == 0:
-                torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)
-
-            x_t_1 = torch.randn_like(real_data)
-            with autocast():
-                fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
-            torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)
-
-            if args.save_content:
-                if epoch % args.save_content_every == 0:
-                    print('Saving content.')
-                    content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
-                               'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
-                               'schedulerG': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
-                               'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
-
-                    torch.save(content, os.path.join(exp_path, 'content.pth'))
-                    torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
-
-            if epoch % args.save_ckpt_every == 0:
-                if args.use_ema:
-                    optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
-
-                torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
-                if args.use_ema:
-                    optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
-                dist.barrier()
-        """
-
 
 def init_processes(rank, size, fn, args):
     """ Initialize the distributed environment. """
@@ -748,12 +733,12 @@ if __name__ == '__main__':
                         help='seed used for initialization')
 
     parser.add_argument('--resume', action='store_true',default=False)
-    parser.add_argument('--masked_mean', action='store_true',default=False)
-    parser.add_argument('--mismatch_loss', action='store_true',default=False)
+    parser.add_argument('--masked_mean', action='store_true',default=False, help="use masked mean to pool from t5-based text encoder")
+    parser.add_argument('--mismatch_loss', action='store_true',default=False, help="use mismatch loss")
     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
-    parser.add_argument('--cross_attention', action='store_true',default=False)
-    parser.add_argument('--fsdp', action='store_true',default=False)
-    parser.add_argument('--grad_checkpointing', action='store_true',default=False)
+    parser.add_argument('--cross_attention', action='store_true',default=False, help="use cross attention in generator")
+    parser.add_argument('--fsdp', action='store_true',default=False, help='use FSDP')
+    parser.add_argument('--grad_checkpointing', action='store_true',default=False, help='use grad checkpointing')
 
     parser.add_argument('--image_size', type=int, default=32,
                         help='size of image')
@@ -767,9 +752,8 @@ if __name__ == '__main__':
     parser.add_argument('--beta_max', type=float, default=20.,
                         help='beta_max for diffusion')
    parser.add_argument('--classifier_free_guidance_proba', type=float, default=0.0)
-
    parser.add_argument('--num_channels_dae', type=int, default=128,
-                        help='number of initial channels in denosing model')
+                        help='number of initial channels in denoising model generator')
    parser.add_argument('--n_mlp', type=int, default=3,
                        help='number of mlp layers for z')
    parser.add_argument('--ch_mult', nargs='+', type=int,
@@ -825,7 +809,7 @@ if __name__ == '__main__':
    parser.add_argument('--beta2', type=float, default=0.9,
                        help='beta2 for adam')
    parser.add_argument('--no_lr_decay',action='store_true', default=False)
-    parser.add_argument('--grad_penalty_cond', action='store_true',default=False)
+    parser.add_argument('--grad_penalty_cond', action='store_true',default=False, help="cond based grad penalty")
 
    parser.add_argument('--use_ema', action='store_true', default=False,
                        help='use EMA or not')
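
One subtlety in the checkpointing block above: netG_{epoch}.pth is saved while the EMA weights are swapped in, and the swap is then undone so training resumes on the raw weights (the pattern is visible in the removed commented-out block). A minimal sketch of that idiom, assuming the EMA wrapper's swap_parameters_with_ema method as used in this repo:

    import torch

    def save_ema_weights(netG, optimizerG, path, use_ema=True):
        # Swap raw <-> EMA parameters, save the (now EMA) weights, swap back.
        # store_params_in_ema=True keeps the raw weights recoverable in the EMA slots.
        if use_ema:
            optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
        torch.save(netG.state_dict(), path)
        if use_ema:
            optimizerG.swap_parameters_with_ema(store_params_in_ema=True)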