# GaussianAnything-AIGC3D: dit/dit_decoder_3d.py
import torch.nn as nn
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from pdb import set_trace as st
from ldm.modules.attention import MemoryEfficientCrossAttention
from .dit_decoder import DiT2


class DiT3D(DiT2):

    def __init__(self,
                 input_size=32,
                 patch_size=2,
                 in_channels=4,
                 hidden_size=1152,
                 depth=28,
                 num_heads=16,
                 mlp_ratio=4,
                 class_dropout_prob=0.1,
                 num_classes=1000,
                 learn_sigma=True,
                 mixing_logit_init=-3,
                 mixed_prediction=True,
                 context_dim=False,
                 roll_out=False,
                 plane_n=3,
                 return_all_layers=False,
                 in_plane_attention=True,
                 vit_blk=...):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, roll_out, plane_n,
                         return_all_layers, in_plane_attention, vit_blk)

        # Follow Point-Infinity: add a "write" cross-attention block every 6 blocks.
        # 25/4/2024: cascade a "read&write" block after the DiT base model.
        self.read_ca = MemoryEfficientCrossAttention(hidden_size, context_dim)
        self.point_infinity_blocks = nn.ModuleList([
            vit_blk(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(2)
        ])

    def initialize_weights(self):
        super().initialize_weights()

        # Zero-out the adaLN modulation layers in the extra read&write blocks
        # (there is no separate final layer here anymore).
        for block in self.point_infinity_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, c, *args, **kwargs):
        x_base = super().forward(c, *args, **kwargs)  # base latent from the DiT2 backbone
        # TODO: apply the read&write cascade (self.read_ca + self.point_infinity_blocks);
        # the original body stops here, so return the base prediction unchanged.
        return x_base
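
    # A minimal sketch (not called by forward above) of how the cascaded
    # read&write refinement could look. Assumptions: the ldm
    # MemoryEfficientCrossAttention is invoked as attn(x, context=...), and
    # vit_blk instances follow the DiTBlock convention block(x, cond) with
    # adaLN conditioning. The helper name and arguments are hypothetical.
    def _read_write_cascade(self, x_base, context, cond):
        x = x_base + self.read_ca(x_base, context=context)  # "read": cross-attend to condition tokens
        for block in self.point_infinity_blocks:  # "write": two extra blocks refine the latent
            x = block(x, cond)
        return x
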
#################################################################################
# DiT3D Configs #
#################################################################################

def DiT3DXL_2(**kwargs):
    return DiT3D(depth=28,
                 hidden_size=1152,
                 patch_size=2,
                 num_heads=16,
                 **kwargs)


def DiT3DXL_2_half(**kwargs):
    return DiT3D(depth=28 // 2,
                 hidden_size=1152,
                 patch_size=2,
                 num_heads=16,
                 **kwargs)


def DiT3DXL_4(**kwargs):
    return DiT3D(depth=28,
                 hidden_size=1152,
                 patch_size=4,
                 num_heads=16,
                 **kwargs)


def DiT3DXL_8(**kwargs):
    return DiT3D(depth=28,
                 hidden_size=1152,
                 patch_size=8,
                 num_heads=16,
                 **kwargs)


def DiT3DL_2(**kwargs):
    return DiT3D(depth=24,
                 hidden_size=1024,
                 patch_size=2,
                 num_heads=16,
                 **kwargs)


def DiT3DL_2_half(**kwargs):
    return DiT3D(depth=24 // 2,
                 hidden_size=1024,
                 patch_size=2,
                 num_heads=16,
                 **kwargs)


def DiT3DL_4(**kwargs):
    return DiT3D(depth=24,
                 hidden_size=1024,
                 patch_size=4,
                 num_heads=16,
                 **kwargs)


def DiT3DL_8(**kwargs):
    return DiT3D(depth=24,
                 hidden_size=1024,
                 patch_size=8,
                 num_heads=16,
                 **kwargs)


def DiT3DB_2(**kwargs):
    return DiT3D(depth=12,
                 hidden_size=768,
                 patch_size=2,
                 num_heads=12,
                 **kwargs)


def DiT3DB_4(**kwargs):
    return DiT3D(depth=12,
                 hidden_size=768,
                 patch_size=4,
                 num_heads=12,
                 **kwargs)


def DiT3DB_8(**kwargs):
    return DiT3D(depth=12,
                 hidden_size=768,
                 patch_size=8,
                 num_heads=12,
                 **kwargs)


def DiT3DB_16(**kwargs):  # ours cfg
    return DiT3D(depth=12,
                 hidden_size=768,
                 patch_size=16,
                 num_heads=12,
                 **kwargs)


def DiT3DS_2(**kwargs):
    return DiT3D(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)


def DiT3DS_4(**kwargs):
    return DiT3D(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)


def DiT3DS_8(**kwargs):
    return DiT3D(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


DiT3Dmodels = {
    'DiT3D-XL/2': DiT3DXL_2,
    'DiT3D-XL/2/half': DiT3DXL_2_half,
    'DiT3D-XL/4': DiT3DXL_4,
    'DiT3D-XL/8': DiT3DXL_8,
    'DiT3D-L/2': DiT3DL_2,
    'DiT3D-L/2/half': DiT3DL_2_half,
    'DiT3D-L/4': DiT3DL_4,
    'DiT3D-L/8': DiT3DL_8,
    'DiT3D-B/2': DiT3DB_2,
    'DiT3D-B/4': DiT3DB_4,
    'DiT3D-B/8': DiT3DB_8,
    'DiT3D-B/16': DiT3DB_16,
    'DiT3D-S/2': DiT3DS_2,
    'DiT3D-S/4': DiT3DS_4,
    'DiT3D-S/8': DiT3DS_8,
}
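

# Usage sketch: look up a config by name and instantiate it. The kwargs shown
# below are illustrative assumptions, not the project's actual launch
# arguments; note that vit_blk defaults to `...` above, so a concrete DiT
# block class must be supplied by the caller.
if __name__ == '__main__':
    model_fn = DiT3Dmodels['DiT3D-B/16']  # "ours cfg" above
    # model = model_fn(input_size=32, in_channels=4, context_dim=768,
    #                  vit_blk=SomeDiTBlock)  # SomeDiTBlock: hypothetical block class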