Spaces:
Running
on
Zero
Running
on
Zero
import torch.nn as nn | |
from inspect import isfunction | |
import math | |
import torch | |
import torch.nn.functional as F | |
from torch import nn, einsum | |
from einops import rearrange, repeat | |
from pdb import set_trace as st | |
from ldm.modules.attention import MemoryEfficientCrossAttention | |
from .dit_decoder import DiT2 | |
class DiT3D(DiT2):
    """DiT2 backbone plus extra modules for a cascaded "read&write" stage.

    On top of the DiT2 base model this constructs a cross-attention module
    (``read_ca``) and two extra ``vit_blk`` blocks (``point_infinity_blocks``).

    NOTE(review): the read&write stage is not wired into ``forward`` yet;
    ``forward`` currently returns the DiT2 output unchanged.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        roll_out=False,
        plane_n=3,
        return_all_layers=False,
        in_plane_attention=True,
        vit_blk=...,
    ):
        """Build the DiT2 base model, then the extra read&write modules.

        All arguments are forwarded positionally to ``DiT2.__init__``.
        ``hidden_size``, ``context_dim``, ``num_heads``, ``mlp_ratio`` and
        ``vit_blk`` are also used to build the extra modules.
        ``vit_blk`` is called as ``vit_blk(hidden_size, num_heads,
        mlp_ratio=mlp_ratio)``. Its default ``...`` is not callable, so callers
        must pass a real block class.
        """
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, roll_out, plane_n,
                         return_all_layers, in_plane_attention, vit_blk)

        # follow point infinity, add "write" CA block per 6 blocks
        # 25/4/2024, cascade a "read&write" block after the DiT base model.
        # NOTE(review): context_dim defaults to False here; confirm that
        # MemoryEfficientCrossAttention treats a falsy context_dim as intended.
        self.read_ca = MemoryEfficientCrossAttention(hidden_size, context_dim)
        self.point_infinity_blocks = nn.ModuleList([
            vit_blk(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(2)
        ])
        # If the base constructor already ran initialize_weights(), these
        # blocks did not exist yet and were skipped there, so zero them now.
        self._zero_init_point_infinity_blocks()

    def initialize_weights(self):
        """Run the base initialization, then zero the extra blocks' adaLN output."""
        super().initialize_weights()
        # ! no final layer anymore
        self._zero_init_point_infinity_blocks()

    def _zero_init_point_infinity_blocks(self):
        """Zero-out the last adaLN modulation layer of each extra block.

        Does nothing if ``point_infinity_blocks`` has not been assigned yet,
        e.g. when called from the base-class constructor before ``__init__``
        above creates the blocks.
        """
        # nn.Module.__getattr__ raises AttributeError for unknown names, so
        # getattr's default covers the "not created yet" case.
        for block in getattr(self, 'point_infinity_blocks', ()):
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, c, *args, **kwargs):
        """Run the DiT2 base model and return its output.

        TODO: apply the read&write stage (``read_ca`` and
        ``point_infinity_blocks``) here. Until then the base latent is returned
        unchanged rather than falling off the end of the method and
        returning ``None``.
        """
        x_base = super().forward(c, *args, **kwargs)  # base latent
        # add read&write block
        return x_base
#################################################################################
#                                DiT3D Configs                                  #
#################################################################################
def _dit3d_xlarge(depth, patch_size, kwargs):
    """Construct an XL-width DiT3D (hidden_size 1152, 16 heads).

    ``kwargs`` is passed as a plain dict (not unpacked into this helper's own
    signature) and forwarded unchanged to ``DiT3D``.
    """
    return DiT3D(depth=depth, hidden_size=1152, patch_size=patch_size,
                 num_heads=16, **kwargs)


def DiT3DXL_2(**kwargs):
    """DiT3D-XL preset: 28 blocks, patch size 2."""
    return _dit3d_xlarge(28, 2, kwargs)


def DiT3DXL_2_half(**kwargs):
    """DiT3D-XL preset at half depth (28 // 2 blocks), patch size 2."""
    return _dit3d_xlarge(28 // 2, 2, kwargs)


def DiT3DXL_4(**kwargs):
    """DiT3D-XL preset: 28 blocks, patch size 4."""
    return _dit3d_xlarge(28, 4, kwargs)


def DiT3DXL_8(**kwargs):
    """DiT3D-XL preset: 28 blocks, patch size 8."""
    return _dit3d_xlarge(28, 8, kwargs)
def _dit3d_large(depth, patch_size, kwargs):
    """Construct a Large-width DiT3D (hidden_size 1024, 16 heads).

    ``kwargs`` is passed as a plain dict (not unpacked into this helper's own
    signature) and forwarded unchanged to ``DiT3D``.
    """
    return DiT3D(depth=depth, hidden_size=1024, patch_size=patch_size,
                 num_heads=16, **kwargs)


def DiT3DL_2(**kwargs):
    """DiT3D-L preset: 24 blocks, patch size 2."""
    return _dit3d_large(24, 2, kwargs)


def DiT3DL_2_half(**kwargs):
    """DiT3D-L preset at half depth (24 // 2 blocks), patch size 2."""
    return _dit3d_large(24 // 2, 2, kwargs)


def DiT3DL_4(**kwargs):
    """DiT3D-L preset: 24 blocks, patch size 4."""
    return _dit3d_large(24, 4, kwargs)


def DiT3DL_8(**kwargs):
    """DiT3D-L preset: 24 blocks, patch size 8."""
    return _dit3d_large(24, 8, kwargs)
def _dit3d_base(patch_size, kwargs):
    """Construct a Base-size DiT3D (12 blocks, hidden_size 768, 12 heads).

    ``kwargs`` is passed as a plain dict (not unpacked into this helper's own
    signature) and forwarded unchanged to ``DiT3D``.
    """
    return DiT3D(depth=12, hidden_size=768, patch_size=patch_size,
                 num_heads=12, **kwargs)


def DiT3DB_2(**kwargs):
    """DiT3D-B preset, patch size 2."""
    return _dit3d_base(2, kwargs)


def DiT3DB_4(**kwargs):
    """DiT3D-B preset, patch size 4."""
    return _dit3d_base(4, kwargs)


def DiT3DB_8(**kwargs):
    """DiT3D-B preset, patch size 8."""
    return _dit3d_base(8, kwargs)


def DiT3DB_16(**kwargs):  # ours cfg
    """DiT3D-B preset, patch size 16."""
    return _dit3d_base(16, kwargs)
def _dit3d_small(patch_size, kwargs):
    """Construct a Small-size DiT3D (12 blocks, hidden_size 384, 6 heads).

    ``kwargs`` is passed as a plain dict (not unpacked into this helper's own
    signature) and forwarded unchanged to ``DiT3D``.
    """
    return DiT3D(depth=12, hidden_size=384, patch_size=patch_size,
                 num_heads=6, **kwargs)


def DiT3DS_2(**kwargs):
    """DiT3D-S preset, patch size 2."""
    return _dit3d_small(2, kwargs)


def DiT3DS_4(**kwargs):
    """DiT3D-S preset, patch size 4."""
    return _dit3d_small(4, kwargs)


def DiT3DS_8(**kwargs):
    """DiT3D-S preset, patch size 8."""
    return _dit3d_small(8, kwargs)
# Ordered (preset name -> factory) pairs; dict() keeps this insertion order.
_DIT3D_MODEL_TABLE = (
    ('DiT3D-XL/2', DiT3DXL_2),
    ('DiT3D-XL/2/half', DiT3DXL_2_half),
    ('DiT3D-XL/4', DiT3DXL_4),
    ('DiT3D-XL/8', DiT3DXL_8),
    ('DiT3D-L/2', DiT3DL_2),
    ('DiT3D-L/2/half', DiT3DL_2_half),
    ('DiT3D-L/4', DiT3DL_4),
    ('DiT3D-L/8', DiT3DL_8),
    ('DiT3D-B/2', DiT3DB_2),
    ('DiT3D-B/4', DiT3DB_4),
    ('DiT3D-B/8', DiT3DB_8),
    ('DiT3D-B/16', DiT3DB_16),
    ('DiT3D-S/2', DiT3DS_2),
    ('DiT3D-S/4', DiT3DS_4),
    ('DiT3D-S/8', DiT3DS_8),
)

# Lookup from preset name to the factory that builds that DiT3D configuration.
DiT3Dmodels = dict(_DIT3D_MODEL_TABLE)