from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm

from .unet_causal_3d_blocks import (
    CausalConv3d,
    UNetMidBlockCausal3D,
    get_down_block3d,
    get_up_block3d,
)


@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: torch.FloatTensor


class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                add_time_downsample = bool(
                    i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
                )
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
            down_block = get_down_block3d(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)

        # middle
        sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
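
# Usage sketch for `EncoderCausal3D` (illustrative, not part of the original
# module). The block configuration below is an assumption chosen so that the
# default 8x spatial / 4x temporal compression is fully exercised; real values
# come from the surrounding VAE config. Under the usual causal-conv convention,
# an input of 1 + 4k frames encodes to 1 + k latent frames:
#
#     encoder = EncoderCausal3D(
#         in_channels=3,
#         out_channels=16,
#         down_block_types=("DownEncoderBlockCausal3D",) * 4,
#         block_out_channels=(128, 256, 512, 512),
#     )
#     video = torch.randn(1, 3, 17, 256, 256)  # (B, C, T, H, W)
#     moments = encoder(video)                 # (1, 32, 5, 32, 32): 2 * out_channels since double_z=True
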
class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_upsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                add_time_upsample = bool(
                    i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
                )
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
            upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
            up_block = get_up_block3d(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=bool(add_spatial_upsample or add_time_upsample),
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class."""
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
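
# Decoder-side sketch (illustrative; the configuration is an assumption that
# mirrors the encoder example above). `gradient_checkpointing` is a plain
# attribute, so training code can toggle it directly to trade recompute for
# memory; the checkpointed branch of `forward` only runs when the module is in
# training mode:
#
#     decoder = DecoderCausal3D(
#         in_channels=16,
#         out_channels=3,
#         up_block_types=("UpDecoderBlockCausal3D",) * 4,
#         block_out_channels=(128, 256, 512, 512),
#     )
#     decoder.gradient_checkpointing = True  # no effect unless decoder.training is True
#     latents = torch.randn(1, 16, 5, 32, 32)
#     video = decoder(latents)               # (1, 3, 17, 256, 256) under the same causal convention
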
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        if parameters.ndim == 3:
            dim = 2  # (B, L, C)
        elif parameters.ndim == 5 or parameters.ndim == 4:
            dim = 1  # (B, C, T, H, W) / (B, C, H, W)
        else:
            raise NotImplementedError
        self.parameters = parameters
        # The channel dimension carries the concatenated mean and log-variance.
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            reduce_dim = list(range(1, self.mean.ndim))
            if other is None:
                # Closed-form KL(N(mean, var) || N(0, I)), summed over all non-batch dims.
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=reduce_dim,
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=reduce_dim,
                )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
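
# End-to-end sketch of the VAE bottleneck (illustrative; `encoder`/`decoder`
# are the instances from the sketches above). The encoder stacks mean and
# log-variance along the channel dimension, which is exactly what
# `DiagonalGaussianDistribution` splits apart with `torch.chunk`:
#
#     moments = encoder(video)                          # (B, 2 * C_latent, T', H', W')
#     posterior = DiagonalGaussianDistribution(moments)
#     z = posterior.sample()                            # reparameterized draw: mean + std * eps
#     kl_loss = posterior.kl().mean()                   # KL against the standard normal prior
#     z_det = posterior.mode()                          # deterministic alternative: just the mean
#     recon = decoder(z)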