Mehdi Cherti committed
Commit 8ab4de9
1 Parent(s): 3dcdf92

add basic cross attention + global attention block

score_sde/models/layers.py CHANGED
@@ -583,7 +583,7 @@ class Identity(nn.Module):
   def forward(self, x, *args, **kwargs):
     return x
 
-
+
 class CrossAttention(nn.Module):
   def __init__(
     self,
score_sde/models/layerspp.py CHANGED
@@ -123,6 +123,34 @@ class AttnBlockpp(nn.Module):
     else:
       return (x + h) / np.sqrt(2.)
 
+class AttnBlockppRaw(nn.Module):
+  """Channel-wise self-attention block. Modified from DDPM."""
+
+  def __init__(self, channels, skip_rescale=False, init_scale=0.):
+    super().__init__()
+    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
+                                    eps=1e-6)
+    self.NIN_0 = NIN(channels, channels)
+    self.NIN_1 = NIN(channels, channels)
+    self.NIN_2 = NIN(channels, channels)
+    self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
+    self.skip_rescale = skip_rescale
+
+  def forward(self, x):
+    B, C, H, W = x.shape
+    h = self.GroupNorm_0(x)
+    q = self.NIN_0(h)
+    k = self.NIN_1(h)
+    v = self.NIN_2(h)
+
+    w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
+    w = torch.reshape(w, (B, H, W, H * W))
+    w = F.softmax(w, dim=-1)
+    w = torch.reshape(w, (B, H, W, H, W))
+    h = torch.einsum('bhwij,bcij->bchw', w, v)
+    h = self.NIN_3(h)
+    return h
+
 
 class Upsample(nn.Module):
   def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
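
The new AttnBlockppRaw mirrors the existing AttnBlockpp but returns the attention output directly, without the residual and skip-rescale step (which is why skip_rescale is stored but never used). As a sanity check (not part of the commit), the einsum formulation above is equivalent to ordinary scaled dot-product attention over flattened spatial tokens; a minimal sketch with the NIN 1x1 projections dropped:

```python
# Sanity check (not from the commit): the einsum attention in AttnBlockppRaw
# matches standard "flatten, matmul, softmax" attention over spatial positions.
# The NIN projections are dropped here; only the attention math is compared.
import torch
import torch.nn.functional as F

B, C, H, W = 2, 8, 4, 4
q, k, v = torch.randn(3, B, C, H, W)

# einsum form used in the block
w = torch.einsum('bchw,bcij->bhwij', q, k) * (C ** -0.5)
w = F.softmax(w.reshape(B, H, W, H * W), dim=-1).reshape(B, H, W, H, W)
out_einsum = torch.einsum('bhwij,bcij->bchw', w, v)

# equivalent flattened form: queries/keys/values as (B, H*W, C) tokens
qf, kf, vf = (t.reshape(B, C, H * W) for t in (q, k, v))
attn = F.softmax(qf.transpose(1, 2) @ kf * (C ** -0.5), dim=-1)  # (B, HW, HW)
out_flat = (attn @ vf.transpose(1, 2)).transpose(1, 2).reshape(B, C, H, W)

assert torch.allclose(out_einsum, out_flat, atol=1e-5)
```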
score_sde/models/ncsnpp_generator_adagn.py CHANGED
@@ -53,6 +53,36 @@ get_act = layers.get_act
 default_initializer = layers.default_init
 dense = dense_layer.dense
 
+class CrossAndGlobalAttnBlock(nn.Module):
+  """Cross-attention to the conditioning sequence plus a global self-attention branch."""
+  def __init__(self, channels, *, context_dim=None, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False):
+    super().__init__()
+    self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
+    self.ca = layers.CrossAttention(
+      channels,
+      context_dim=context_dim,
+      dim_head=dim_head,
+      heads=heads,
+      norm_context=norm_context,
+      cosine_sim_attn=cosine_sim_attn,
+    )
+    self.attn = layerspp.AttnBlockppRaw(channels)
+
+  def forward(self, x, cond, mask=None):
+    B, C, H, W = x.shape
+    h = self.GroupNorm_0(x)
+    h = h.view(B, C, H * W)
+    h = h.permute(0, 2, 1)
+    h = h.contiguous()
+    h_new = self.ca(h, cond, mask=mask)
+    h_new = h_new.permute(0, 2, 1)
+    h_new = h_new.contiguous()
+    h_new = h_new.view(B, C, H, W)
+
+    h_global = self.attn(x)
+    h = h_new + h_global
+    return x + h
+
 class PixelNorm(nn.Module):
   def __init__(self):
     super().__init__()

@@ -68,6 +98,7 @@ class NCSNpp(nn.Module):
   def __init__(self, config):
     super().__init__()
     self.config = config
+    self.cross_attention_block = config.cross_attention_block
     self.grad_checkpointing = config.grad_checkpointing if hasattr(config, "grad_checkpointing") else False
     self.not_use_tanh = config.not_use_tanh
     self.act = act = nn.SiLU()

@@ -124,7 +155,14 @@ class NCSNpp(nn.Module):
     modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
     nn.init.zeros_(modules[-1].bias)
     if config.cross_attention:
-      AttnBlock = functools.partial(layers.CondAttnBlock, context_dim=config.cond_size)
+
+      #block_name = config.cross_attention_block if hasattr(config, "cross_attention_block") else "basic"
+      block_name = config.cross_attention_block
+      if block_name == "basic":
+        AttnBlock = functools.partial(layers.CondAttnBlock, context_dim=config.cond_size)
+      elif block_name == "cross_and_global_attention":
+        AttnBlock = functools.partial(CrossAndGlobalAttnBlock, context_dim=config.cond_size)
+      print(AttnBlock)
     else:
       AttnBlock = functools.partial(layerspp.AttnBlockpp,
                                     init_scale=init_scale,

@@ -342,7 +380,7 @@ class NCSNpp(nn.Module):
         h = modules[m_idx](hs[-1], temb, zemb)
         m_idx += 1
         if h.shape[-1] in self.attn_resolutions:
-          if type(modules[m_idx]) == layers.CondAttnBlock:
+          if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
             h = modules[m_idx](h, cond, cond_mask)
           else:
             h = modules[m_idx](h)

@@ -377,7 +415,7 @@ class NCSNpp(nn.Module):
     h = hs[-1]
     h = modules[m_idx](h, temb, zemb)
     m_idx += 1
-    if type(modules[m_idx]) == layers.CondAttnBlock:
+    if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
       h = modules[m_idx](h, cond, cond_mask)
     else:
       h = modules[m_idx](h)

@@ -394,7 +432,7 @@ class NCSNpp(nn.Module):
         m_idx += 1
 
         if h.shape[-1] in self.attn_resolutions:
-          if type(modules[m_idx]) == layers.CondAttnBlock:
+          if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
            h = modules[m_idx](h, cond, cond_mask)
          else:
            h = modules[m_idx](h)
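
For context, the data flow in CrossAndGlobalAttnBlock is: group-normalize, flatten the (B, C, H, W) feature map into (B, H*W, C) tokens, cross-attend to the conditioning sequence cond, reshape back to a feature map, add a global self-attention branch computed on the raw input, and finish with a residual connection. A minimal self-contained sketch of that plumbing, using nn.MultiheadAttention as a stand-in for layers.CrossAttention and an identity placeholder for layerspp.AttnBlockppRaw (both stand-ins are assumptions, not the repo's API; the mask argument is omitted):

```python
# Hypothetical stand-alone sketch of the CrossAndGlobalAttnBlock data flow.
# nn.MultiheadAttention replaces layers.CrossAttention and nn.Identity replaces
# the global AttnBlockppRaw branch; only the tensor plumbing is illustrated.
import torch
import torch.nn as nn

class CrossAndGlobalSketch(nn.Module):
    def __init__(self, channels, context_dim, heads=8):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
        # stand-in for layers.CrossAttention(channels, context_dim=..., heads=..., ...)
        self.ca = nn.MultiheadAttention(channels, heads, kdim=context_dim,
                                        vdim=context_dim, batch_first=True)
        # stand-in for layerspp.AttnBlockppRaw(channels)
        self.global_attn = nn.Identity()

    def forward(self, x, cond):
        B, C, H, W = x.shape
        # normalize and flatten the feature map into (B, H*W, C) query tokens
        h = self.norm(x).view(B, C, H * W).permute(0, 2, 1)
        # cross-attention: image tokens query the conditioning sequence
        h_cross, _ = self.ca(h, cond, cond)
        h_cross = h_cross.permute(0, 2, 1).reshape(B, C, H, W)
        # global self-attention branch computed on the raw input
        h_global = self.global_attn(x)
        # sum the two branches and add the residual, as in the commit
        return x + h_cross + h_global

blk = CrossAndGlobalSketch(channels=256, context_dim=768)
x = torch.randn(2, 256, 16, 16)    # feature map
cond = torch.randn(2, 10, 768)     # e.g. 10 text-encoder tokens
print(blk(x, cond).shape)          # torch.Size([2, 256, 16, 16])
```

One design note visible in the diff: the block hard-codes num_groups=32, so it assumes the channel count is divisible by 32, whereas AttnBlockpp and AttnBlockppRaw use min(channels // 4, 32).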
train_ddgan.py CHANGED
@@ -385,9 +385,10 @@ def train(rank, gpu, args):
         backbone_kwargs={"cond_size": text_encoder.output_size}
     )
     netD = netD.to(device)
-
-    broadcast_params(netG.parameters())
-    broadcast_params(netD.parameters())
+
+    if args.world_size > 1:
+        broadcast_params(netG.parameters())
+        broadcast_params(netD.parameters())
 
     if args.fsdp:
         from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

@@ -410,8 +411,9 @@ def train(rank, gpu, args):
     if args.fsdp:
         netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
     else:
-        netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
-        netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
+        if args.world_size > 1:
+            netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
+            netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
     #if args.discr_type == "projected_gan":
     #    netD._set_static_graph()
 

@@ -652,7 +654,8 @@ def train(rank, gpu, args):
                     torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}_iteration_{}.png'.format(epoch, iteration)), normalize=True)
 
             if args.save_content:
-                dist.barrier()
+                if args.world_size > 1:
+                    dist.barrier()
                 if rank == 0:
                     print('Saving content.')
                     def to_cpu(d):

@@ -709,20 +712,26 @@ def init_processes(rank, size, fn, args):
    """ Initialize the distributed environment. """

    import os
-
-    args.rank = int(os.environ['SLURM_PROCID'])
-    args.world_size = int(os.getenv("SLURM_NTASKS"))
-    args.local_rank = int(os.environ['SLURM_LOCALID'])
-    print(args.rank, args.world_size)
-    args.master_address = os.getenv("SLURM_LAUNCH_NODE_IPADDR")
-    os.environ['MASTER_ADDR'] = args.master_address
-    os.environ['MASTER_PORT'] = "12345"
-    torch.cuda.set_device(args.local_rank)
-    gpu = args.local_rank
-    dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=args.world_size)
-    fn(rank, gpu, args)
-    dist.barrier()
-    cleanup()
+
+    if size == 1:
+        args.rank = 0
+        args.world_size = 1
+        args.local_rank = 0
+        fn(rank, args.local_rank, args)
+    else:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.world_size = int(os.getenv("SLURM_NTASKS"))
+        args.local_rank = int(os.environ['SLURM_LOCALID'])
+        print(args.rank, args.world_size)
+        args.master_address = os.getenv("SLURM_LAUNCH_NODE_IPADDR")
+        os.environ['MASTER_ADDR'] = args.master_address
+        os.environ['MASTER_PORT'] = "12345"
+        torch.cuda.set_device(args.local_rank)
+        gpu = args.local_rank
+        dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=args.world_size)
+        fn(rank, gpu, args)
+        dist.barrier()
+        cleanup()
 
 def cleanup():
     dist.destroy_process_group()

@@ -737,6 +746,8 @@ if __name__ == '__main__':
     parser.add_argument('--mismatch_loss', action='store_true',default=False, help="use mismatch loss")
     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
     parser.add_argument('--cross_attention', action='store_true',default=False, help="use cross attention in generator")
+    parser.add_argument('--cross_attention_block', default="basic", help="cross attention block type")
+
     parser.add_argument('--fsdp', action='store_true',default=False, help='use FSDP')
     parser.add_argument('--grad_checkpointing', action='store_true',default=False, help='use grad checkpointing')
 

@@ -809,7 +820,7 @@ if __name__ == '__main__':
     parser.add_argument('--beta2', type=float, default=0.9,
                             help='beta2 for adam')
     parser.add_argument('--no_lr_decay',action='store_true', default=False)
-    parser.add_argument('--grad_penalty_cond', action='store_true',default=False, help="cond based grad penalty")
+    parser.add_argument('--grad_penalty_cond', action='store_true',default=False, help="cond based grad")
 
     parser.add_argument('--use_ema', action='store_true', default=False,
                             help='use EMA or not')

@@ -828,6 +839,7 @@ if __name__ == '__main__':
     parser.add_argument('--precision', type=str, default="fp32")
 
     ###ddp
+
     parser.add_argument('--num_proc_node', type=int, default=1,
                             help='The number of nodes in multi node env.')
     parser.add_argument('--num_process_per_node', type=int, default=1,

@@ -840,8 +852,10 @@ if __name__ == '__main__':
                             help='address for master')
 
     args = parser.parse_args()
-    # args.world_size = args.num_proc_node * args.num_process_per_node
-    args.world_size = int(os.getenv("SLURM_NTASKS"))
-    args.rank = int(os.environ['SLURM_PROCID'])
-    # size = args.num_process_per_node
+    if 'SLURM_NTASKS' in os.environ:
+        args.world_size = int(os.getenv("SLURM_NTASKS"))
+        args.rank = int(os.environ['SLURM_PROCID'])
+    else:
+        args.world_size = 1
+        args.rank = 0
     init_processes(args.rank, args.world_size, train, args)
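
The train_ddgan.py changes make the SLURM environment optional: when SLURM_NTASKS is not set, the script falls back to a single process with world_size = 1, and the DDP wrapping, parameter broadcast, and dist.barrier() calls are skipped. A minimal sketch of that detection pattern (a hypothetical helper, not code from the commit):

```python
# Hypothetical helper illustrating the fallback the patch adds to train_ddgan.py:
# use the SLURM_* variables when they are set, otherwise run one local process.
import os

def resolve_dist_env():
    if 'SLURM_NTASKS' in os.environ:
        return dict(
            world_size=int(os.environ['SLURM_NTASKS']),
            rank=int(os.environ['SLURM_PROCID']),
            local_rank=int(os.environ['SLURM_LOCALID']),
            master_addr=os.environ.get('SLURM_LAUNCH_NODE_IPADDR', '127.0.0.1'),
        )
    return dict(world_size=1, rank=0, local_rank=0, master_addr='127.0.0.1')

print(resolve_dist_env())
```

With the new flag, a single-GPU run selecting the new block would then look something like python train_ddgan.py --cross_attention --cross_attention_block cross_and_global_attention ... (remaining arguments omitted).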