erichardson commited on
Commit
43d3c68
1 Parent(s): 66d4261

VAE: Support more configurations for Encoder and Decoder blocks

Browse files

VAE: Define encoder compress-all block with channel multiplier

VAE: Support residual connection in the decoder

VAE: Refactor CausalConv3d parameters

lint

xora/models/autoencoders/causal_conv3d.py CHANGED
@@ -11,6 +11,8 @@ class CausalConv3d(nn.Module):
11
  out_channels,
12
  kernel_size: int = 3,
13
  stride: Union[int, Tuple[int]] = 1,
 
 
14
  **kwargs,
15
  ):
16
  super().__init__()
@@ -21,7 +23,6 @@ class CausalConv3d(nn.Module):
21
  kernel_size = (kernel_size, kernel_size, kernel_size)
22
  self.time_kernel_size = kernel_size[0]
23
 
24
- dilation = kwargs.pop("dilation", 1)
25
  dilation = (dilation, 1, 1)
26
 
27
  height_pad = kernel_size[1] // 2
@@ -36,6 +37,7 @@ class CausalConv3d(nn.Module):
36
  dilation=dilation,
37
  padding=padding,
38
  padding_mode="zeros",
 
39
  )
40
 
41
  def forward(self, x, causal: bool = True):
 
11
  out_channels,
12
  kernel_size: int = 3,
13
  stride: Union[int, Tuple[int]] = 1,
14
+ dilation: int = 1,
15
+ groups: int = 1,
16
  **kwargs,
17
  ):
18
  super().__init__()
 
23
  kernel_size = (kernel_size, kernel_size, kernel_size)
24
  self.time_kernel_size = kernel_size[0]
25
 
 
26
  dilation = (dilation, 1, 1)
27
 
28
  height_pad = kernel_size[1] // 2
 
37
  dilation=dilation,
38
  padding=padding,
39
  padding_mode="zeros",
40
+ groups=groups,
41
  )
42
 
43
  def forward(self, x, causal: bool = True):
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -78,7 +78,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
78
  dims=config["dims"],
79
  in_channels=config.get("in_channels", 3),
80
  out_channels=config["latent_channels"],
81
- blocks=config["blocks"],
82
  patch_size=config.get("patch_size", 1),
83
  latent_log_var=latent_log_var,
84
  norm_layer=config.get("norm_layer", "group_norm"),
@@ -88,7 +88,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
88
  dims=config["dims"],
89
  in_channels=config["latent_channels"],
90
  out_channels=config.get("out_channels", 3),
91
- blocks=config["blocks"],
92
  patch_size=config.get("patch_size", 1),
93
  norm_layer=config.get("norm_layer", "group_norm"),
94
  causal=config.get("causal_decoder", False),
@@ -112,7 +112,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
112
  out_channels=self.decoder.conv_out.out_channels
113
  // self.decoder.patch_size**2,
114
  latent_channels=self.decoder.conv_in.in_channels,
115
- blocks=self.encoder.blocks_desc,
 
116
  scaling_factor=1.0,
117
  norm_layer=self.encoder.norm_layer,
118
  patch_size=self.encoder.patch_size,
@@ -242,7 +243,7 @@ class Encoder(nn.Module):
242
  dims: Union[int, Tuple[int, int]] = 3,
243
  in_channels: int = 3,
244
  out_channels: int = 3,
245
- blocks: List[Tuple[str, int]] = [("res_x", 1)],
246
  base_channels: int = 128,
247
  norm_num_groups: int = 32,
248
  patch_size: Union[int, Tuple[int]] = 1,
@@ -271,20 +272,22 @@ class Encoder(nn.Module):
271
 
272
  self.down_blocks = nn.ModuleList([])
273
 
274
- for block_name, num_layers in blocks:
275
  input_channel = output_channel
 
 
276
 
277
  if block_name == "res_x":
278
  block = UNetMidBlock3D(
279
  dims=dims,
280
  in_channels=input_channel,
281
- num_layers=num_layers,
282
  resnet_eps=1e-6,
283
  resnet_groups=norm_num_groups,
284
  norm_layer=norm_layer,
285
  )
286
  elif block_name == "res_x_y":
287
- output_channel = 2 * output_channel
288
  block = ResnetBlock3D(
289
  dims=dims,
290
  in_channels=input_channel,
@@ -320,6 +323,16 @@ class Encoder(nn.Module):
320
  stride=(2, 2, 2),
321
  causal=True,
322
  )
 
 
 
 
 
 
 
 
 
 
323
  else:
324
  raise ValueError(f"unknown block: {block_name}")
325
 
@@ -421,7 +434,7 @@ class Decoder(nn.Module):
421
  dims,
422
  in_channels: int = 3,
423
  out_channels: int = 3,
424
- blocks: List[Tuple[str, int]] = [("res_x", 1)],
425
  base_channels: int = 128,
426
  layers_per_block: int = 2,
427
  norm_num_groups: int = 32,
@@ -433,9 +446,15 @@ class Decoder(nn.Module):
433
  self.patch_size = patch_size
434
  self.layers_per_block = layers_per_block
435
  out_channels = out_channels * patch_size**2
436
- num_channel_doubles = len([x for x in blocks if x[0] == "res_x_y"])
437
- output_channel = base_channels * 2**num_channel_doubles
438
  self.causal = causal
 
 
 
 
 
 
 
 
439
 
440
  self.conv_in = make_conv_nd(
441
  dims,
@@ -449,20 +468,22 @@ class Decoder(nn.Module):
449
 
450
  self.up_blocks = nn.ModuleList([])
451
 
452
- for block_name, num_layers in list(reversed(blocks)):
453
  input_channel = output_channel
 
 
454
 
455
  if block_name == "res_x":
456
  block = UNetMidBlock3D(
457
  dims=dims,
458
  in_channels=input_channel,
459
- num_layers=num_layers,
460
  resnet_eps=1e-6,
461
  resnet_groups=norm_num_groups,
462
  norm_layer=norm_layer,
463
  )
464
  elif block_name == "res_x_y":
465
- output_channel = output_channel // 2
466
  block = ResnetBlock3D(
467
  dims=dims,
468
  in_channels=input_channel,
@@ -481,7 +502,10 @@ class Decoder(nn.Module):
481
  )
482
  elif block_name == "compress_all":
483
  block = DepthToSpaceUpsample(
484
- dims=dims, in_channels=input_channel, stride=(2, 2, 2)
 
 
 
485
  )
486
  else:
487
  raise ValueError(f"unknown layer: {block_name}")
@@ -590,7 +614,7 @@ class UNetMidBlock3D(nn.Module):
590
 
591
 
592
  class DepthToSpaceUpsample(nn.Module):
593
- def __init__(self, dims, in_channels, stride):
594
  super().__init__()
595
  self.stride = stride
596
  self.out_channels = np.prod(stride) * in_channels
@@ -602,8 +626,21 @@ class DepthToSpaceUpsample(nn.Module):
602
  stride=1,
603
  causal=True,
604
  )
 
605
 
606
  def forward(self, x, causal: bool = True):
 
 
 
 
 
 
 
 
 
 
 
 
607
  x = self.conv(x, causal=causal)
608
  x = rearrange(
609
  x,
@@ -614,6 +651,8 @@ class DepthToSpaceUpsample(nn.Module):
614
  )
615
  if self.stride[0] == 2:
616
  x = x[:, :, 1:, :, :]
 
 
617
  return x
618
 
619
 
@@ -647,7 +686,6 @@ class ResnetBlock3D(nn.Module):
647
  dims: Union[int, Tuple[int, int]],
648
  in_channels: int,
649
  out_channels: Optional[int] = None,
650
- conv_shortcut: bool = False,
651
  dropout: float = 0.0,
652
  groups: int = 32,
653
  eps: float = 1e-6,
@@ -657,7 +695,6 @@ class ResnetBlock3D(nn.Module):
657
  self.in_channels = in_channels
658
  out_channels = in_channels if out_channels is None else out_channels
659
  self.out_channels = out_channels
660
- self.use_conv_shortcut = conv_shortcut
661
 
662
  if norm_layer == "group_norm":
663
  self.norm1 = nn.GroupNorm(
 
78
  dims=config["dims"],
79
  in_channels=config.get("in_channels", 3),
80
  out_channels=config["latent_channels"],
81
+ blocks=config.get("encoder_blocks", config.get("blocks")),
82
  patch_size=config.get("patch_size", 1),
83
  latent_log_var=latent_log_var,
84
  norm_layer=config.get("norm_layer", "group_norm"),
 
88
  dims=config["dims"],
89
  in_channels=config["latent_channels"],
90
  out_channels=config.get("out_channels", 3),
91
+ blocks=config.get("decoder_blocks", config.get("blocks")),
92
  patch_size=config.get("patch_size", 1),
93
  norm_layer=config.get("norm_layer", "group_norm"),
94
  causal=config.get("causal_decoder", False),
 
112
  out_channels=self.decoder.conv_out.out_channels
113
  // self.decoder.patch_size**2,
114
  latent_channels=self.decoder.conv_in.in_channels,
115
+ encoder_blocks=self.encoder.blocks_desc,
116
+ decoder_blocks=self.decoder.blocks_desc,
117
  scaling_factor=1.0,
118
  norm_layer=self.encoder.norm_layer,
119
  patch_size=self.encoder.patch_size,
 
243
  dims: Union[int, Tuple[int, int]] = 3,
244
  in_channels: int = 3,
245
  out_channels: int = 3,
246
+ blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
247
  base_channels: int = 128,
248
  norm_num_groups: int = 32,
249
  patch_size: Union[int, Tuple[int]] = 1,
 
272
 
273
  self.down_blocks = nn.ModuleList([])
274
 
275
+ for block_name, block_params in blocks:
276
  input_channel = output_channel
277
+ if isinstance(block_params, int):
278
+ block_params = {"num_layers": block_params}
279
 
280
  if block_name == "res_x":
281
  block = UNetMidBlock3D(
282
  dims=dims,
283
  in_channels=input_channel,
284
+ num_layers=block_params["num_layers"],
285
  resnet_eps=1e-6,
286
  resnet_groups=norm_num_groups,
287
  norm_layer=norm_layer,
288
  )
289
  elif block_name == "res_x_y":
290
+ output_channel = block_params.get("multiplier", 2) * output_channel
291
  block = ResnetBlock3D(
292
  dims=dims,
293
  in_channels=input_channel,
 
323
  stride=(2, 2, 2),
324
  causal=True,
325
  )
326
+ elif block_name == "compress_all_x_y":
327
+ output_channel = block_params.get("multiplier", 2) * output_channel
328
+ block = make_conv_nd(
329
+ dims=dims,
330
+ in_channels=input_channel,
331
+ out_channels=output_channel,
332
+ kernel_size=3,
333
+ stride=(2, 2, 2),
334
+ causal=True,
335
+ )
336
  else:
337
  raise ValueError(f"unknown block: {block_name}")
338
 
 
434
  dims,
435
  in_channels: int = 3,
436
  out_channels: int = 3,
437
+ blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
438
  base_channels: int = 128,
439
  layers_per_block: int = 2,
440
  norm_num_groups: int = 32,
 
446
  self.patch_size = patch_size
447
  self.layers_per_block = layers_per_block
448
  out_channels = out_channels * patch_size**2
 
 
449
  self.causal = causal
450
+ self.blocks_desc = blocks
451
+
452
+ # Compute output channel to be product of all channel-multiplier blocks
453
+ output_channel = base_channels
454
+ for block_name, block_params in list(reversed(blocks)):
455
+ block_params = block_params if isinstance(block_params, dict) else {}
456
+ if block_name == "res_x_y":
457
+ output_channel = output_channel * block_params.get("multiplier", 2)
458
 
459
  self.conv_in = make_conv_nd(
460
  dims,
 
468
 
469
  self.up_blocks = nn.ModuleList([])
470
 
471
+ for block_name, block_params in list(reversed(blocks)):
472
  input_channel = output_channel
473
+ if isinstance(block_params, int):
474
+ block_params = {"num_layers": block_params}
475
 
476
  if block_name == "res_x":
477
  block = UNetMidBlock3D(
478
  dims=dims,
479
  in_channels=input_channel,
480
+ num_layers=block_params["num_layers"],
481
  resnet_eps=1e-6,
482
  resnet_groups=norm_num_groups,
483
  norm_layer=norm_layer,
484
  )
485
  elif block_name == "res_x_y":
486
+ output_channel = output_channel // block_params.get("multiplier", 2)
487
  block = ResnetBlock3D(
488
  dims=dims,
489
  in_channels=input_channel,
 
502
  )
503
  elif block_name == "compress_all":
504
  block = DepthToSpaceUpsample(
505
+ dims=dims,
506
+ in_channels=input_channel,
507
+ stride=(2, 2, 2),
508
+ residual=block_params.get("residual", False),
509
  )
510
  else:
511
  raise ValueError(f"unknown layer: {block_name}")
 
614
 
615
 
616
  class DepthToSpaceUpsample(nn.Module):
617
+ def __init__(self, dims, in_channels, stride, residual=False):
618
  super().__init__()
619
  self.stride = stride
620
  self.out_channels = np.prod(stride) * in_channels
 
626
  stride=1,
627
  causal=True,
628
  )
629
+ self.residual = residual
630
 
631
  def forward(self, x, causal: bool = True):
632
+ if self.residual:
633
+ # Reshape and duplicate the input to match the output shape
634
+ x_in = rearrange(
635
+ x,
636
+ "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
637
+ p1=self.stride[0],
638
+ p2=self.stride[1],
639
+ p3=self.stride[2],
640
+ )
641
+ x_in = x_in.repeat(1, np.prod(self.stride), 1, 1, 1)
642
+ if self.stride[0] == 2:
643
+ x_in = x_in[:, :, 1:, :, :]
644
  x = self.conv(x, causal=causal)
645
  x = rearrange(
646
  x,
 
651
  )
652
  if self.stride[0] == 2:
653
  x = x[:, :, 1:, :, :]
654
+ if self.residual:
655
+ x = x + x_in
656
  return x
657
 
658
 
 
686
  dims: Union[int, Tuple[int, int]],
687
  in_channels: int,
688
  out_channels: Optional[int] = None,
 
689
  dropout: float = 0.0,
690
  groups: int = 32,
691
  eps: float = 1e-6,
 
695
  self.in_channels = in_channels
696
  out_channels = in_channels if out_channels is None else out_channels
697
  self.out_channels = out_channels
 
698
 
699
  if norm_layer == "group_norm":
700
  self.norm1 = nn.GroupNorm(