ashawkey commited on
Commit
503bb19
·
verified ·
1 Parent(s): 78eec61

Update unet/mv_unet.py

Browse files
Files changed (1) hide show
  1. unet/mv_unet.py +0 -84
unet/mv_unet.py CHANGED
@@ -39,55 +39,6 @@ def get_camera(
39
  return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
 
41
 
42
- def checkpoint(func, inputs, params, flag):
43
- """
44
- Evaluate a function without caching intermediate activations, allowing for
45
- reduced memory at the expense of extra compute in the backward pass.
46
- :param func: the function to evaluate.
47
- :param inputs: the argument sequence to pass to `func`.
48
- :param params: a sequence of parameters `func` depends on but does not
49
- explicitly take as arguments.
50
- :param flag: if False, disable gradient checkpointing.
51
- """
52
- if flag:
53
- args = tuple(inputs) + tuple(params)
54
- return CheckpointFunction.apply(func, len(inputs), *args)
55
- else:
56
- return func(*inputs)
57
-
58
-
59
- class CheckpointFunction(torch.autograd.Function):
60
- @staticmethod
61
- def forward(ctx, run_function, length, *args):
62
- ctx.run_function = run_function
63
- ctx.input_tensors = list(args[:length])
64
- ctx.input_params = list(args[length:])
65
-
66
- with torch.no_grad():
67
- output_tensors = ctx.run_function(*ctx.input_tensors)
68
- return output_tensors
69
-
70
- @staticmethod
71
- def backward(ctx, *output_grads):
72
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
73
- with torch.enable_grad():
74
- # Fixes a bug where the first op in run_function modifies the
75
- # Tensor storage in place, which is not allowed for detach()'d
76
- # Tensors.
77
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
78
- output_tensors = ctx.run_function(*shallow_copies)
79
- input_grads = torch.autograd.grad(
80
- output_tensors,
81
- ctx.input_tensors + ctx.input_params,
82
- output_grads,
83
- allow_unused=True,
84
- )
85
- del ctx.input_tensors
86
- del ctx.input_params
87
- del output_tensors
88
- return (None, None) + input_grads
89
-
90
-
91
  def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
92
  """
93
  Create sinusoidal timestep embeddings.
@@ -286,7 +237,6 @@ class BasicTransformerBlock3D(nn.Module):
286
  context_dim,
287
  dropout=0.0,
288
  gated_ff=True,
289
- checkpoint=True,
290
  ip_dim=0,
291
  ip_weight=1,
292
  ):
@@ -313,14 +263,8 @@ class BasicTransformerBlock3D(nn.Module):
313
  self.norm1 = nn.LayerNorm(dim)
314
  self.norm2 = nn.LayerNorm(dim)
315
  self.norm3 = nn.LayerNorm(dim)
316
- self.checkpoint = checkpoint
317
 
318
  def forward(self, x, context=None, num_frames=1):
319
- return checkpoint(
320
- self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
321
- )
322
-
323
- def _forward(self, x, context=None, num_frames=1):
324
  x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
325
  x = self.attn1(self.norm1(x), context=None) + x
326
  x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
@@ -341,7 +285,6 @@ class SpatialTransformer3D(nn.Module):
341
  dropout=0.0,
342
  ip_dim=0,
343
  ip_weight=1,
344
- use_checkpoint=True,
345
  ):
346
  super().__init__()
347
 
@@ -362,7 +305,6 @@ class SpatialTransformer3D(nn.Module):
362
  d_head,
363
  context_dim=context_dim[d],
364
  dropout=dropout,
365
- checkpoint=use_checkpoint,
366
  ip_dim=ip_dim,
367
  ip_weight=ip_weight,
368
  )
@@ -581,7 +523,6 @@ class ResBlock(nn.Module):
581
  convolution instead of a smaller 1x1 convolution to change the
582
  channels in the skip connection.
583
  :param dims: determines if the signal is 1D, 2D, or 3D.
584
- :param use_checkpoint: if True, use gradient checkpointing on this module.
585
  :param up: if True, use this block for upsampling.
586
  :param down: if True, use this block for downsampling.
587
  """
@@ -595,7 +536,6 @@ class ResBlock(nn.Module):
595
  use_conv=False,
596
  use_scale_shift_norm=False,
597
  dims=2,
598
- use_checkpoint=False,
599
  up=False,
600
  down=False,
601
  ):
@@ -605,7 +545,6 @@ class ResBlock(nn.Module):
605
  self.dropout = dropout
606
  self.out_channels = out_channels or channels
607
  self.use_conv = use_conv
608
- self.use_checkpoint = use_checkpoint
609
  self.use_scale_shift_norm = use_scale_shift_norm
610
 
611
  self.in_layers = nn.Sequential(
@@ -651,17 +590,6 @@ class ResBlock(nn.Module):
651
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
652
 
653
  def forward(self, x, emb):
654
- """
655
- Apply the block to a Tensor, conditioned on a timestep embedding.
656
- :param x: an [N x C x ...] Tensor of features.
657
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
658
- :return: an [N x C x ...] Tensor of outputs.
659
- """
660
- return checkpoint(
661
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
662
- )
663
-
664
- def _forward(self, x, emb):
665
  if self.updown:
666
  in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
667
  h = in_rest(x)
@@ -702,7 +630,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
702
  :param dims: determines if the signal is 1D, 2D, or 3D.
703
  :param num_classes: if specified (as an int), then this model will be
704
  class-conditional with `num_classes` classes.
705
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
706
  :param num_heads: the number of attention heads in each attention layer.
707
  :param num_heads_channels: if specified, ignore num_heads and instead use
708
  a fixed channel width per attention head.
@@ -728,7 +655,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
728
  conv_resample=True,
729
  dims=2,
730
  num_classes=None,
731
- use_checkpoint=False,
732
  num_heads=-1,
733
  num_head_channels=-1,
734
  num_heads_upsample=-1,
@@ -794,7 +720,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
794
  self.channel_mult = channel_mult
795
  self.conv_resample = conv_resample
796
  self.num_classes = num_classes
797
- self.use_checkpoint = use_checkpoint
798
  self.num_heads = num_heads
799
  self.num_head_channels = num_head_channels
800
  self.num_heads_upsample = num_heads_upsample
@@ -868,7 +793,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
868
  dropout,
869
  out_channels=mult * model_channels,
870
  dims=dims,
871
- use_checkpoint=use_checkpoint,
872
  use_scale_shift_norm=use_scale_shift_norm,
873
  )
874
  ]
@@ -888,7 +812,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
888
  dim_head,
889
  context_dim=context_dim,
890
  depth=transformer_depth,
891
- use_checkpoint=use_checkpoint,
892
  ip_dim=self.ip_dim,
893
  ip_weight=self.ip_weight,
894
  )
@@ -906,7 +829,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
906
  dropout,
907
  out_channels=out_ch,
908
  dims=dims,
909
- use_checkpoint=use_checkpoint,
910
  use_scale_shift_norm=use_scale_shift_norm,
911
  down=True,
912
  )
@@ -933,7 +855,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
933
  time_embed_dim,
934
  dropout,
935
  dims=dims,
936
- use_checkpoint=use_checkpoint,
937
  use_scale_shift_norm=use_scale_shift_norm,
938
  ),
939
  SpatialTransformer3D(
@@ -942,7 +863,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
942
  dim_head,
943
  context_dim=context_dim,
944
  depth=transformer_depth,
945
- use_checkpoint=use_checkpoint,
946
  ip_dim=self.ip_dim,
947
  ip_weight=self.ip_weight,
948
  ),
@@ -951,7 +871,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
951
  time_embed_dim,
952
  dropout,
953
  dims=dims,
954
- use_checkpoint=use_checkpoint,
955
  use_scale_shift_norm=use_scale_shift_norm,
956
  ),
957
  )
@@ -968,7 +887,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
968
  dropout,
969
  out_channels=model_channels * mult,
970
  dims=dims,
971
- use_checkpoint=use_checkpoint,
972
  use_scale_shift_norm=use_scale_shift_norm,
973
  )
974
  ]
@@ -988,7 +906,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
988
  dim_head,
989
  context_dim=context_dim,
990
  depth=transformer_depth,
991
- use_checkpoint=use_checkpoint,
992
  ip_dim=self.ip_dim,
993
  ip_weight=self.ip_weight,
994
  )
@@ -1002,7 +919,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
1002
  dropout,
1003
  out_channels=out_ch,
1004
  dims=dims,
1005
- use_checkpoint=use_checkpoint,
1006
  use_scale_shift_norm=use_scale_shift_norm,
1007
  up=True,
1008
  )
 
39
  return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
43
  """
44
  Create sinusoidal timestep embeddings.
 
237
  context_dim,
238
  dropout=0.0,
239
  gated_ff=True,
 
240
  ip_dim=0,
241
  ip_weight=1,
242
  ):
 
263
  self.norm1 = nn.LayerNorm(dim)
264
  self.norm2 = nn.LayerNorm(dim)
265
  self.norm3 = nn.LayerNorm(dim)
 
266
 
267
  def forward(self, x, context=None, num_frames=1):
 
 
 
 
 
268
  x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
269
  x = self.attn1(self.norm1(x), context=None) + x
270
  x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
 
285
  dropout=0.0,
286
  ip_dim=0,
287
  ip_weight=1,
 
288
  ):
289
  super().__init__()
290
 
 
305
  d_head,
306
  context_dim=context_dim[d],
307
  dropout=dropout,
 
308
  ip_dim=ip_dim,
309
  ip_weight=ip_weight,
310
  )
 
523
  convolution instead of a smaller 1x1 convolution to change the
524
  channels in the skip connection.
525
  :param dims: determines if the signal is 1D, 2D, or 3D.
 
526
  :param up: if True, use this block for upsampling.
527
  :param down: if True, use this block for downsampling.
528
  """
 
536
  use_conv=False,
537
  use_scale_shift_norm=False,
538
  dims=2,
 
539
  up=False,
540
  down=False,
541
  ):
 
545
  self.dropout = dropout
546
  self.out_channels = out_channels or channels
547
  self.use_conv = use_conv
 
548
  self.use_scale_shift_norm = use_scale_shift_norm
549
 
550
  self.in_layers = nn.Sequential(
 
590
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
591
 
592
  def forward(self, x, emb):
 
 
 
 
 
 
 
 
 
 
 
593
  if self.updown:
594
  in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
595
  h = in_rest(x)
 
630
  :param dims: determines if the signal is 1D, 2D, or 3D.
631
  :param num_classes: if specified (as an int), then this model will be
632
  class-conditional with `num_classes` classes.
 
633
  :param num_heads: the number of attention heads in each attention layer.
634
  :param num_heads_channels: if specified, ignore num_heads and instead use
635
  a fixed channel width per attention head.
 
655
  conv_resample=True,
656
  dims=2,
657
  num_classes=None,
 
658
  num_heads=-1,
659
  num_head_channels=-1,
660
  num_heads_upsample=-1,
 
720
  self.channel_mult = channel_mult
721
  self.conv_resample = conv_resample
722
  self.num_classes = num_classes
 
723
  self.num_heads = num_heads
724
  self.num_head_channels = num_head_channels
725
  self.num_heads_upsample = num_heads_upsample
 
793
  dropout,
794
  out_channels=mult * model_channels,
795
  dims=dims,
 
796
  use_scale_shift_norm=use_scale_shift_norm,
797
  )
798
  ]
 
812
  dim_head,
813
  context_dim=context_dim,
814
  depth=transformer_depth,
 
815
  ip_dim=self.ip_dim,
816
  ip_weight=self.ip_weight,
817
  )
 
829
  dropout,
830
  out_channels=out_ch,
831
  dims=dims,
 
832
  use_scale_shift_norm=use_scale_shift_norm,
833
  down=True,
834
  )
 
855
  time_embed_dim,
856
  dropout,
857
  dims=dims,
 
858
  use_scale_shift_norm=use_scale_shift_norm,
859
  ),
860
  SpatialTransformer3D(
 
863
  dim_head,
864
  context_dim=context_dim,
865
  depth=transformer_depth,
 
866
  ip_dim=self.ip_dim,
867
  ip_weight=self.ip_weight,
868
  ),
 
871
  time_embed_dim,
872
  dropout,
873
  dims=dims,
 
874
  use_scale_shift_norm=use_scale_shift_norm,
875
  ),
876
  )
 
887
  dropout,
888
  out_channels=model_channels * mult,
889
  dims=dims,
 
890
  use_scale_shift_norm=use_scale_shift_norm,
891
  )
892
  ]
 
906
  dim_head,
907
  context_dim=context_dim,
908
  depth=transformer_depth,
 
909
  ip_dim=self.ip_dim,
910
  ip_weight=self.ip_weight,
911
  )
 
919
  dropout,
920
  out_channels=out_ch,
921
  dims=dims,
 
922
  use_scale_shift_norm=use_scale_shift_norm,
923
  up=True,
924
  )