Mehdi Cherti commited on
Commit
23d6920
1 Parent(s): bc53ac3

- support different preprocessing

Browse files

- add shuffling to wds
- support cross attention for discr in training
-

Files changed (2) hide show
  1. scripts/init.sh +4 -1
  2. train_ddgan.py +59 -59
scripts/init.sh CHANGED
@@ -8,7 +8,10 @@ if [[ "$machine" == jurecadc ]]; then
8
  ml OpenMPI/4.1.2
9
  ml CUDA/11.5
10
  ml cuDNN/8.3.1.22-CUDA-11.5
11
- ml NCCL/2.12.7-1-CUDA-11.5
 
 
 
12
  ml PyTorch/1.11-CUDA-11.5
13
  ml Horovod/0.24
14
  ml torchvision/0.12.0
 
8
  ml OpenMPI/4.1.2
9
  ml CUDA/11.5
10
  ml cuDNN/8.3.1.22-CUDA-11.5
11
+
12
+ ml NCCL/2.11.4-CUDA-11.5
13
+ #ml NCCL/2.12.7-1-CUDA-11.5
14
+
15
  ml PyTorch/1.11-CUDA-11.5
16
  ml Horovod/0.24
17
  ml torchvision/0.12.0
train_ddgan.py CHANGED
@@ -195,14 +195,14 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
195
  from utils import ResampledShards2
196
 
197
  def train(rank, gpu, args):
198
- from score_sde.models.discriminator import Discriminator_small, Discriminator_large
199
  from score_sde.models.ncsnpp_generator_adagn import NCSNpp
200
  from EMA import EMA
201
 
202
  torch.manual_seed(args.seed + rank)
203
  torch.cuda.manual_seed(args.seed + rank)
204
  torch.cuda.manual_seed_all(args.seed + rank)
205
- device = torch.device('cuda:{}'.format(gpu))
206
 
207
  batch_size = args.batch_size
208
 
@@ -254,19 +254,28 @@ def train(rank, gpu, args):
254
  dataset = ImageFolder(root=args.dataset_root, transform=train_transform)
255
  elif args.dataset == 'wds':
256
  import webdataset as wds
257
- train_transform = transforms.Compose([
258
- transforms.Resize(args.image_size),
259
- transforms.CenterCrop(args.image_size),
260
- # transforms.RandomHorizontalFlip(),
261
- transforms.ToTensor(),
262
- transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
 
 
 
 
 
 
263
  ])
264
- # pipeline = [wds.SimpleShardList(args.dataset_root)]
265
  pipeline = [ResampledShards2(args.dataset_root)]
266
  pipeline.extend([
267
  wds.split_by_node,
268
  wds.split_by_worker,
269
  wds.tarfile_to_samples(handler=log_and_continue),
 
 
 
 
270
  ])
271
  pipeline.extend([
272
  wds.decode("pilrgb", handler=log_and_continue),
@@ -284,16 +293,20 @@ def train(rank, gpu, args):
284
  )
285
 
286
  if args.dataset != "wds":
287
- train_sampler = torch.utils.data.distributed.DistributedSampler(dataset,
288
- num_replicas=args.world_size,
289
- rank=rank)
290
- data_loader = torch.utils.data.DataLoader(dataset,
291
- batch_size=batch_size,
292
- shuffle=False,
293
- num_workers=4,
294
- drop_last=True,
295
- pin_memory=True,
296
- sampler=train_sampler,)
 
 
 
 
297
  text_encoder = t5.T5Encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
298
  args.cond_size = text_encoder.output_size
299
  netG = NCSNpp(args).to(device)
@@ -302,18 +315,30 @@ def train(rank, gpu, args):
302
  nb_params += param.flatten().shape[0]
303
  print("Number of generator parameters:", nb_params)
304
 
305
-
306
- if args.dataset == 'cifar10' or args.dataset == 'stackmnist':
307
  netD = Discriminator_small(nc = 2*args.num_channels, ngf = args.ngf,
308
  t_emb_dim = args.t_emb_dim,
309
  cond_size=text_encoder.output_size,
310
  act=nn.LeakyReLU(0.2)).to(device)
311
- else:
 
 
 
 
 
 
312
  netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
313
- t_emb_dim = args.t_emb_dim,
314
  cond_size=text_encoder.output_size,
315
- act=nn.LeakyReLU(0.2)).to(device)
316
-
 
 
 
 
 
 
 
317
  broadcast_params(netG.parameters())
318
  broadcast_params(netD.parameters())
319
 
@@ -326,13 +351,9 @@ def train(rank, gpu, args):
326
  schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
327
  schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
328
 
329
-
330
-
331
- #ddp
332
  netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
333
  netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
334
 
335
-
336
  exp = args.exp
337
  parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
338
 
@@ -343,7 +364,6 @@ def train(rank, gpu, args):
343
  copy_source(__file__, exp_path)
344
  shutil.copytree('score_sde/models', os.path.join(exp_path, 'score_sde/models'))
345
 
346
-
347
  coeff = Diffusion_Coefficients(args, device)
348
  pos_coeff = Posterior_Coefficients(args, device)
349
  T = get_time_schedule(args, device)
@@ -368,7 +388,6 @@ def train(rank, gpu, args):
368
  else:
369
  global_step, epoch, init_epoch = 0, 0, 0
370
 
371
-
372
  for epoch in range(init_epoch, args.num_epoch+1):
373
  if args.dataset == "wds":
374
  os.environ["WDS_EPOCH"] = str(epoch)
@@ -388,7 +407,6 @@ def train(rank, gpu, args):
388
 
389
  for p in netD.parameters():
390
  p.requires_grad = True
391
-
392
 
393
  netD.zero_grad()
394
 
@@ -401,9 +419,10 @@ def train(rank, gpu, args):
401
  x_t, x_tp1 = q_sample_pairs(coeff, real_data, t)
402
  x_t.requires_grad = True
403
 
404
-
 
405
  # train with real
406
- D_real = netD(x_t, t, x_tp1.detach(), cond=cond_pooled).view(-1)
407
 
408
  errD_real = F.softplus(-D_real)
409
  errD_real = errD_real.mean()
@@ -442,7 +461,7 @@ def train(rank, gpu, args):
442
  x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
443
  x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
444
 
445
- output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_pooled).view(-1)
446
 
447
 
448
  errD_fake = F.softplus(output)
@@ -474,7 +493,7 @@ def train(rank, gpu, args):
474
  x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
475
  x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
476
 
477
- output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_pooled).view(-1)
478
 
479
 
480
  errG = F.softplus(-output)
@@ -658,7 +677,9 @@ if __name__ == '__main__':
658
  parser.add_argument('--save_content', action='store_true',default=False)
659
  parser.add_argument('--save_content_every', type=int, default=50, help='save content for resuming every x epochs')
660
  parser.add_argument('--save_ckpt_every', type=int, default=25, help='save ckpt every x epochs')
661
-
 
 
662
  ###ddp
663
  parser.add_argument('--num_proc_node', type=int, default=1,
664
  help='The number of nodes in multi node env.')
@@ -671,30 +692,9 @@ if __name__ == '__main__':
671
  parser.add_argument('--master_address', type=str, default='127.0.0.1',
672
  help='address for master')
673
 
674
-
675
  args = parser.parse_args()
676
  # args.world_size = args.num_proc_node * args.num_process_per_node
677
  args.world_size = int(os.getenv("SLURM_NTASKS"))
678
  args.rank = int(os.environ['SLURM_PROCID'])
679
  # size = args.num_process_per_node
680
- init_processes(args.rank, args.world_size, train, args)
681
- # if size > 1:
682
- # processes = []
683
- # for rank in range(size):
684
- # args.local_rank = rank
685
- # global_rank = rank + args.node_rank * args.num_process_per_node
686
- # global_size = args.num_proc_node * args.num_process_per_node
687
- # args.global_rank = global_rank
688
- # print('Node rank %d, local proc %d, global proc %d' % (args.node_rank, rank, global_rank))
689
- # p = Process(target=init_processes, args=(global_rank, global_size, train, args))
690
- # p.start()
691
- # processes.append(p)
692
-
693
- # for p in processes:
694
- # p.join()
695
- # else:
696
- # print('starting in debug mode')
697
-
698
- # init_processes(0, size, train, args)
699
-
700
-
 
195
  from utils import ResampledShards2
196
 
197
  def train(rank, gpu, args):
198
+ from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
199
  from score_sde.models.ncsnpp_generator_adagn import NCSNpp
200
  from EMA import EMA
201
 
202
  torch.manual_seed(args.seed + rank)
203
  torch.cuda.manual_seed(args.seed + rank)
204
  torch.cuda.manual_seed_all(args.seed + rank)
205
+ device = "cuda"
206
 
207
  batch_size = args.batch_size
208
 
 
254
  dataset = ImageFolder(root=args.dataset_root, transform=train_transform)
255
  elif args.dataset == 'wds':
256
  import webdataset as wds
257
+ if args.preprocessing == "resize":
258
+ train_transform = transforms.Compose([
259
+ transforms.Resize(args.image_size),
260
+ transforms.CenterCrop(args.image_size),
261
+ transforms.ToTensor(),
262
+ transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
263
+ ])
264
+ elif args.preprocessing == "random_resized_crop_v1":
265
+ train_transform = transforms.Compose([
266
+ transforms.RandomResizedCrop(256, scale=(0.95, 1.0), interpolation=3),
267
+ transforms.ToTensor(),
268
+ transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
269
  ])
 
270
  pipeline = [ResampledShards2(args.dataset_root)]
271
  pipeline.extend([
272
  wds.split_by_node,
273
  wds.split_by_worker,
274
  wds.tarfile_to_samples(handler=log_and_continue),
275
+ wds.shuffle(
276
+ bufsize=5000,
277
+ initial=1000,
278
+ ),
279
  ])
280
  pipeline.extend([
281
  wds.decode("pilrgb", handler=log_and_continue),
 
293
  )
294
 
295
  if args.dataset != "wds":
296
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
297
+ dataset,
298
+ num_replicas=args.world_size,
299
+ rank=rank
300
+ )
301
+ data_loader = torch.utils.data.DataLoader(
302
+ dataset,
303
+ batch_size=batch_size,
304
+ shuffle=False,
305
+ num_workers=4,
306
+ drop_last=True,
307
+ pin_memory=True,
308
+ sampler=train_sampler,
309
+ )
310
  text_encoder = t5.T5Encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
311
  args.cond_size = text_encoder.output_size
312
  netG = NCSNpp(args).to(device)
 
315
  nb_params += param.flatten().shape[0]
316
  print("Number of generator parameters:", nb_params)
317
 
318
+ if args.discr_type == "small":
 
319
  netD = Discriminator_small(nc = 2*args.num_channels, ngf = args.ngf,
320
  t_emb_dim = args.t_emb_dim,
321
  cond_size=text_encoder.output_size,
322
  act=nn.LeakyReLU(0.2)).to(device)
323
+ elif args.discr_type == "small_cond_attn":
324
+ netD = SmallCondAttnDiscriminator(nc = 2*args.num_channels, ngf = args.ngf,
325
+ t_emb_dim = args.t_emb_dim,
326
+ cond_size=text_encoder.output_size,
327
+ act=nn.LeakyReLU(0.2)).to(device)
328
+
329
+ elif args.discr_type == "large":
330
  netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
331
+ t_emb_dim = args.t_emb_dim,
332
  cond_size=text_encoder.output_size,
333
+ act=nn.LeakyReLU(0.2)).to(device)
334
+ elif args.discr_type == "large_cond_attn":
335
+ netD = CondAttnDiscriminator(
336
+ nc = 2*args.num_channels,
337
+ ngf = args.ngf,
338
+ t_emb_dim = args.t_emb_dim,
339
+ cond_size=text_encoder.output_size,
340
+ act=nn.LeakyReLU(0.2)).to(device)
341
+
342
  broadcast_params(netG.parameters())
343
  broadcast_params(netD.parameters())
344
 
 
351
  schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
352
  schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
353
 
 
 
 
354
  netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
355
  netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
356
 
 
357
  exp = args.exp
358
  parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
359
 
 
364
  copy_source(__file__, exp_path)
365
  shutil.copytree('score_sde/models', os.path.join(exp_path, 'score_sde/models'))
366
 
 
367
  coeff = Diffusion_Coefficients(args, device)
368
  pos_coeff = Posterior_Coefficients(args, device)
369
  T = get_time_schedule(args, device)
 
388
  else:
389
  global_step, epoch, init_epoch = 0, 0, 0
390
 
 
391
  for epoch in range(init_epoch, args.num_epoch+1):
392
  if args.dataset == "wds":
393
  os.environ["WDS_EPOCH"] = str(epoch)
 
407
 
408
  for p in netD.parameters():
409
  p.requires_grad = True
 
410
 
411
  netD.zero_grad()
412
 
 
419
  x_t, x_tp1 = q_sample_pairs(coeff, real_data, t)
420
  x_t.requires_grad = True
421
 
422
+ cond_for_discr = (cond_pooled, cond, cond_mask) if args.discr_type in ("large_cond_attn", "small_cond_attn") else cond_pooled
423
+
424
  # train with real
425
+ D_real = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
426
 
427
  errD_real = F.softplus(-D_real)
428
  errD_real = errD_real.mean()
 
461
  x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
462
  x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
463
 
464
+ output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
465
 
466
 
467
  errD_fake = F.softplus(output)
 
493
  x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
494
  x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
495
 
496
+ output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
497
 
498
 
499
  errG = F.softplus(-output)
 
677
  parser.add_argument('--save_content', action='store_true',default=False)
678
  parser.add_argument('--save_content_every', type=int, default=50, help='save content for resuming every x epochs')
679
  parser.add_argument('--save_ckpt_every', type=int, default=25, help='save ckpt every x epochs')
680
+ parser.add_argument('--discr_type', type=str, default="large")
681
+ parser.add_argument('--preprocessing', type=str, default="resize")
682
+
683
  ###ddp
684
  parser.add_argument('--num_proc_node', type=int, default=1,
685
  help='The number of nodes in multi node env.')
 
692
  parser.add_argument('--master_address', type=str, default='127.0.0.1',
693
  help='address for master')
694
 
 
695
  args = parser.parse_args()
696
  # args.world_size = args.num_proc_node * args.num_process_per_node
697
  args.world_size = int(os.getenv("SLURM_NTASKS"))
698
  args.rank = int(os.environ['SLURM_PROCID'])
699
  # size = args.num_process_per_node
700
+ init_processes(args.rank, args.world_size, train, args)