import torch.nn as nn
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from pdb import set_trace as st

from ldm.modules.attention import MemoryEfficientCrossAttention
from .dit_decoder import DiT2


class DiT3D(DiT2):
    """DiT2 backbone plus extra cross-attention / transformer blocks.

    Follows point infinity: add a "write" CA block per 6 blocks.
    25/4/2024: cascade a "read&write" block after the DiT base model.

    On top of the DiT2 base this module builds:
      * ``read_ca``: a ``MemoryEfficientCrossAttention(hidden_size, context_dim)``.
      * ``point_infinity_blocks``: two ``vit_blk(hidden_size, num_heads,
        mlp_ratio=mlp_ratio)`` blocks whose adaLN modulation output layers
        are zero-initialised.

    NOTE(review): ``forward`` does not apply ``read_ca`` or
    ``point_infinity_blocks`` yet; it returns the DiT2 output unchanged.
    """

    def __init__(self,
                 input_size=32,
                 patch_size=2,
                 in_channels=4,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 mlp_ratio=4,
                 class_dropout_prob=0.1,
                 num_classes=1000,
                 learn_sigma=True,
                 mixing_logit_init=-3,
                 mixed_prediction=True,
                 context_dim=False,
                 roll_out=False,
                 plane_n=3,
                 return_all_layers=False,
                 in_plane_attention=True,
                 vit_blk=...):
        """Build the DiT2 base and the extra read/write blocks.

        All arguments are forwarded positionally to ``DiT2.__init__``.

        Args:
            hidden_size: Token width; also the query dim of ``read_ca`` and the
                width of each extra block.
            num_heads: Attention heads for the extra blocks.
            mlp_ratio: MLP ratio for the extra blocks.
            context_dim: Passed as the context dim of ``read_ca``.
                NOTE(review): the default ``False`` is forwarded as-is; confirm
                that is intended when no conditioning context is used.
            vit_blk: Block class/factory called as
                ``vit_blk(hidden_size, num_heads, mlp_ratio=...)``. Instances
                must expose ``adaLN_modulation[-1]`` with ``weight``/``bias``.

        Raises:
            TypeError: if ``vit_blk`` is not callable (e.g. the ``...`` default).
        """
        # The Ellipsis default cannot build blocks; fail early with a clear
        # message instead of an opaque "'ellipsis' object is not callable".
        if not callable(vit_blk):
            raise TypeError(
                f"DiT3D requires a callable `vit_blk` block class, got {vit_blk!r}")

        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, roll_out, plane_n,
                         return_all_layers, in_plane_attention, vit_blk)

        # follow point infinity, add "write" CA block per 6 blocks
        # 25/4/2024, cascade a "read&write" block after the DiT base model.
        self.read_ca = MemoryEfficientCrossAttention(hidden_size, context_dim)

        self.point_infinity_blocks = nn.ModuleList([
            vit_blk(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(2)
        ])

        # The blocks above are created after DiT2.__init__, so any
        # initialize_weights() call made from the base constructor could not
        # reach them; zero-init them explicitly now.
        self._zero_init_point_infinity_blocks()

    def initialize_weights(self):
        """Run the DiT2 initialisation, then zero the extra blocks' adaLN output.

        ! no final layer anymore
        """
        super().initialize_weights()
        # May be invoked from DiT2.__init__ before point_infinity_blocks
        # exists; the helper is a no-op in that case.
        self._zero_init_point_infinity_blocks()

    def _zero_init_point_infinity_blocks(self):
        """Zero-out the adaLN modulation output layer of each extra block.

        Does nothing if ``point_infinity_blocks`` has not been built yet.
        """
        blocks = getattr(self, 'point_infinity_blocks', None)
        if blocks is None:
            return
        for block in blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, c, *args, **kwargs):
        """Run the DiT2 base model and return its output (the base latent).

        All arguments are forwarded unchanged to ``DiT2.forward``.
        """
        x_base = super().forward(c, *args, **kwargs)  # base latent

        # TODO: add read&write block (self.read_ca / self.point_infinity_blocks
        # are constructed but not applied yet).
        return x_base


#################################################################################
#                                   DiT3D Configs                                  #
#################################################################################


def DiT3DXL_2(**kwargs):
    return DiT3D(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)


def DiT3DXL_2_half(**kwargs):
    return DiT3D(depth=28 // 2, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)


def DiT3DXL_4(**kwargs):
    return DiT3D(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)


def DiT3DXL_8(**kwargs):
    return DiT3D(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)


def DiT3DL_2(**kwargs):
    return DiT3D(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)


def DiT3DL_2_half(**kwargs):
    return DiT3D(depth=24 // 2, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)


def DiT3DL_4(**kwargs):
    return DiT3D(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)


def DiT3DL_8(**kwargs):
    return DiT3D(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)


def DiT3DB_2(**kwargs):
    return DiT3D(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)


def DiT3DB_4(**kwargs):
    return DiT3D(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)


def DiT3DB_8(**kwargs):
    return DiT3D(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)


def DiT3DB_16(**kwargs):  # ours cfg
    return DiT3D(depth=12, hidden_size=768, patch_size=16, num_heads=12, **kwargs)


def DiT3DS_2(**kwargs):
    return DiT3D(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)


def DiT3DS_4(**kwargs):
    return DiT3D(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)


def DiT3DS_8(**kwargs):
    return DiT3D(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


# Registry mapping config name -> factory function.
DiT3Dmodels = {
    'DiT3D-XL/2': DiT3DXL_2,
    'DiT3D-XL/2/half': DiT3DXL_2_half,
    'DiT3D-XL/4': DiT3DXL_4,
    'DiT3D-XL/8': DiT3DXL_8,
    'DiT3D-L/2': DiT3DL_2,
    'DiT3D-L/2/half': DiT3DL_2_half,
    'DiT3D-L/4': DiT3DL_4,
    'DiT3D-L/8': DiT3DL_8,
    'DiT3D-B/2': DiT3DB_2,
    'DiT3D-B/4': DiT3DB_4,
    'DiT3D-B/8': DiT3DB_8,
    'DiT3D-B/16': DiT3DB_16,
    'DiT3D-S/2': DiT3DS_2,
    'DiT3D-S/4': DiT3DS_4,
    'DiT3D-S/8': DiT3DS_8,
}