import torch
from torch import nn
from einops import rearrange
from pdb import set_trace as st

from .dit_models_xformers import *
from torch.nn import LayerNorm
try:
    # prefer the fused apex kernel when available
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    from dit.norm import RMSNorm
from timm.models.vision_transformer import Mlp

from vit.vit_triplane import XYZPosEmbed
from .dit_trilatent import DiT_PCD_PixelArt
class DiT_I23D(DiT):
    # DiT with 3D-aware operations for image-conditioned 3D generation (I23D)
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlock,
final_layer_blk=T2IFinalLayer,
enable_rope=False,
):
# st()
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
mixed_prediction, context_dim, roll_out, vit_blk,
T2IFinalLayer, enable_rope=enable_rope)
assert self.roll_out
# if context_dim is not None:
# self.dino_proj = CaptionEmbedder(context_dim,
        self.clip_ctx_dim = 1024  # CLIP ViT-L spatial token dim
        # projection for DINO spatial tokens (dino-vitb/14 features, used for MV-cond; dim passed via context_dim, hard-coded for now)
        self.dino_proj = CaptionEmbedder(context_dim,
                                         hidden_size,
                                         act_layer=approx_gelu)
self.clip_spatial_proj = CaptionEmbedder(1024, # clip_I-L
hidden_size,
act_layer=approx_gelu)
def init_PE_3D_aware(self):
self.pos_embed = nn.Parameter(torch.zeros(
1, self.plane_n * self.x_embedder.num_patches, self.embed_dim),
requires_grad=False)
# Initialize (and freeze) pos_embed by sin-cos embedding:
p = int(self.x_embedder.num_patches**0.5)
D = self.pos_embed.shape[-1]
grid_size = (self.plane_n, p * p) # B n HW C
pos_embed = get_2d_sincos_pos_embed(D, grid_size).reshape(
self.plane_n * p * p, D) # H*W, D
self.pos_embed.data.copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0))
def initialize_weights(self):
super().initialize_weights()
# ! add 3d-aware PE
self.init_PE_3D_aware()
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# context = self.clip_text_proj(context)
clip_cls_token = self.clip_text_proj(context['vector'])
clip_spatial_token, dino_spatial_token = context['crossattn'][..., :self.clip_ctx_dim], self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:])
t = self.t_embedder(timesteps) + clip_cls_token # (N, D)
# ! todo, return spatial clip features.
# if self.roll_out: # !
        x = rearrange(x, 'b (c n) h w->(b n) c h w',
                      n=3)  # roll the 3 planes into the batch dim so they share the same patchify conv
x = self.x_embedder(x) # (b n) c h/f w/f
x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
x = x + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
# if self.roll_out: # ! roll-out in the L dim, not B dim. add condition to all tokens.
# x = rearrange(x, '(b n) l c ->b (n l) c', n=3)
# assert context.ndim == 2
# if isinstance(context, dict):
# context = context['crossattn'] # sgm conditioner compat
# c = t + context
# else:
# c = t # BS 1024
for blk_idx, block in enumerate(self.blocks):
x = block(x, t, dino_spatial_token=dino_spatial_token, clip_spatial_token=clip_spatial_token) # (N, T, D)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
if self.roll_out: # move n from L to B axis
x = rearrange(x, 'b (n l) c ->(b n) l c', n=3)
x = self.unpatchify(x) # (N, out_channels, H, W)
if self.roll_out: # move n from L to B axis
x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)
# x = rearrange(x, 'b n) c h w -> b (n c) h w', n=3)
# cast to float32 for better accuracy
x = x.to(torch.float32).contiguous()
return x
# ! compat issue
def forward_with_cfg(self, x, t, context, cfg_scale):
"""
Forward pass of SiT, but also batches the unconSiTional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
# half = x[: len(x) // 2]
# combined = torch.cat([half, half], dim=0)
eps = self.forward(x, t, context)
# eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
# eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return eps
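
# Hedged illustration (not part of the original model code) of the classifier-free-guidance
# mixing performed by forward_with_cfg above. Following the GLIDE notebook linked there,
# the caller is expected to stack the conditional and unconditional halves along dim 0.
def _cfg_mix_example(eps: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    """Illustrative only: split cond/uncond predictions and blend them, as forward_with_cfg does."""
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    return torch.cat([half_eps, half_eps], dim=0)
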
class DiT_I23D_PixelArt(DiT_I23D):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
final_layer_blk=FinalLayer,
create_cap_embedder=True,
enable_rope=False,
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
# mixed_prediction, context_dim, roll_out, ImageCondDiTBlockPixelArt,
mixed_prediction, context_dim, pooling_ctx_dim, roll_out, vit_blk,
final_layer_blk,
enable_rope=enable_rope)
        # a single adaLN modulation MLP shared across all blocks (adaLN-single; see the sketch after this class)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        # zero-init so the modulation starts from zeros
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
del self.clip_text_proj
if create_cap_embedder:
            self.cap_embedder = nn.Sequential(  # last Linear zero-initialized below
LayerNorm(pooling_ctx_dim),
nn.Linear(
pooling_ctx_dim,
hidden_size,
),
)
nn.init.constant_(self.cap_embedder[-1].weight, 0)
nn.init.constant_(self.cap_embedder[-1].bias, 0)
else:
self.cap_embedder = nn.Identity() # placeholder
print(self) # check model arch
self.attention_y_norm = RMSNorm(
1024, eps=1e-5
) # https://github.com/Alpha-VLLM/Lumina-T2X/blob/0c8dd6a07a3b7c18da3d91f37b1e00e7ae661293/lumina_t2i/models/model.py#L570C9-L570C61
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# context = self.clip_text_proj(context)
clip_cls_token = self.cap_embedder(context['vector'])
clip_spatial_token, dino_spatial_token = context['crossattn'][..., :self.clip_ctx_dim], self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:])
clip_spatial_token = self.attention_y_norm(clip_spatial_token) # avoid re-normalization in each blk
t = self.t_embedder(timesteps) + clip_cls_token # (N, D)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
# if self.roll_out: # !
x = rearrange(x, 'b (c n) h w->(b n) c h w',
n=3) # downsample with same conv
x = self.x_embedder(x) # (b n) c h/f w/f
x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
x = x + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
# if self.roll_out: # ! roll-out in the L dim, not B dim. add condition to all tokens.
# x = rearrange(x, '(b n) l c ->b (n l) c', n=3)
# assert context.ndim == 2
# if isinstance(context, dict):
# context = context['crossattn'] # sgm conditioner compat
# c = t + context
# else:
# c = t # BS 1024
for blk_idx, block in enumerate(self.blocks):
x = block(x, t0, dino_spatial_token=dino_spatial_token, clip_spatial_token=clip_spatial_token) # (N, T, D)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
if self.roll_out: # move n from L to B axis
x = rearrange(x, 'b (n l) c ->(b n) l c', n=3)
x = self.unpatchify(x) # (N, out_channels, H, W)
if self.roll_out: # move n from L to B axis
x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)
# x = rearrange(x, 'b n) c h w -> b (n c) h w', n=3)
# cast to float32 for better accuracy
x = x.to(torch.float32).contiguous()
return x
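
# Hedged sketch of how the shared adaLN-single vector produced by self.adaLN_modulation
# (shape (B, 6 * hidden_size)) is typically consumed: chunked into shift/scale/gate pairs
# for the attention and MLP branches, as in DiT/PixArt-alpha. The actual consumption here
# lives inside the block classes (dit_models_xformers), so the names below are assumptions.
def _adaln_single_chunks_example(t0: torch.Tensor):
    """Illustrative only: split a (B, 6*D) modulation vector into six (B, D) parts."""
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t0.chunk(6, dim=1)
    return shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
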
class DiT_I23D_PCD_PixelArt(DiT_I23D_PixelArt):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
final_layer_blk=FinalLayer,
create_cap_embedder=True,
use_clay_ca=False,
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
# mixed_prediction, context_dim, roll_out, ImageCondDiTBlockPixelArt,
mixed_prediction, context_dim, pooling_ctx_dim, roll_out, vit_blk,
final_layer_blk)
self.x_embedder = Mlp(in_features=in_channels,
hidden_features=hidden_size,
out_features=hidden_size,
act_layer=approx_gelu,
drop=0)
del self.pos_embed
self.use_clay_ca = use_clay_ca
if use_clay_ca:
del self.dino_proj # no prepending required.
        # add ln_pre and ln_post, as in point-e (did not help; worse performance)
# self.ln_pre = LayerNorm(hidden_size)
# self.ln_post = LayerNorm(hidden_size)
@staticmethod
def precompute_freqs_cis(
dim: int,
end: int,
theta: float = 10000.0,
rope_scaling_factor: float = 1.0,
ntk_factor: float = 1.0,
):
"""
Precompute the frequency tensor for complex exponentials (cis) with
given dimensions.
This function calculates a frequency tensor with complex exponentials
using the given dimension 'dim' and the end index 'end'. The 'theta'
parameter scales the frequencies. The returned tensor contains complex
values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
end (int): End index for precomputing frequencies.
theta (float, optional): Scaling factor for frequency computation.
Defaults to 10000.0.
Returns:
torch.Tensor: Precomputed frequency tensor with complex
exponentials.
"""
theta = theta * ntk_factor
print(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float().cuda() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
t = t / rope_scaling_factor
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
# t = self.t_embedder(timesteps)
# global condition
if 'caption_vector' in context:
clip_cls_token = self.cap_embedder(context['caption_vector'])
elif 'img_vector' in context:
clip_cls_token = self.cap_embedder(context['img_vector'])
else:
clip_cls_token = 0
# spatial condition
clip_spatial_token, dino_spatial_token = context['img_crossattn'][..., :self.clip_ctx_dim], context['img_crossattn'][..., self.clip_ctx_dim:]
if not self.use_clay_ca:
dino_spatial_token=self.dino_proj(dino_spatial_token)
t = self.t_embedder(timesteps) + clip_cls_token # (N, D)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
x = self.x_embedder(x)
# add a norm layer here, as in point-e
# x = self.ln_pre(x)
for blk_idx, block in enumerate(self.blocks):
x = block(x, t0, dino_spatial_token=dino_spatial_token, clip_spatial_token=clip_spatial_token, clip_caption_token=context.get('caption_crossattn'))
# add a norm layer here, as in point-e
# x = self.ln_post(x)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = x.to(torch.float32).contiguous()
return x
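
# Standalone usage sketch for DiT_I23D_PCD_PixelArt.precompute_freqs_cis above (illustrative
# only). It mirrors the same computation on CPU (the method places freqs on CUDA) and shows
# the per-forward slicing used by the RoPE-enabled variant (`self.freqs_cis[: x.size(1)]`).
def _rope_table_example(head_dim: int = 64, max_len: int = 40000, seq_len: int = 768):
    theta = 10000.0
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2)[: head_dim // 2].float() / head_dim))
    t = torch.arange(max_len, dtype=torch.float)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64, (max_len, head_dim // 2)
    return freqs_cis[:seq_len]  # sliced to the current token count
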
# dino only version
class DiT_I23D_PCD_PixelArt_noclip(DiT_I23D_PixelArt):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArtRMSNormNoClip,
final_layer_blk=FinalLayer,
create_cap_embedder=True,
use_clay_ca=False,
has_caption=False,
# has_rope=False,
rope_scaling_factor: float = 1.0,
ntk_factor: float = 1.0,
enable_rope=False,
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
# mixed_prediction, context_dim, roll_out, ImageCondDiTBlockPixelArt,
mixed_prediction, context_dim, pooling_ctx_dim, roll_out, vit_blk,
final_layer_blk, enable_rope=enable_rope)
self.x_embedder = Mlp(in_features=in_channels,
hidden_features=hidden_size,
out_features=hidden_size,
act_layer=approx_gelu,
drop=0)
del self.pos_embed
del self.dino_proj
self.enable_rope = enable_rope
if self.enable_rope: # implementation copied from Lumina-T2X code base
self.freqs_cis = DiT_I23D_PCD_PixelArt.precompute_freqs_cis(
hidden_size // num_heads,
40000,
rope_scaling_factor=rope_scaling_factor,
ntk_factor=ntk_factor,
)
else:
self.freqs_cis = None
self.rope_scaling_factor = rope_scaling_factor
self.ntk_factor = ntk_factor
self.use_clay_ca = use_clay_ca
self.has_caption = has_caption
pooled_vector_dim = context_dim
if has_caption:
pooled_vector_dim += 768
        self.pooled_vec_embedder = nn.Sequential(  # last Linear zero-initialized below
LayerNorm(pooled_vector_dim),
nn.Linear(
pooled_vector_dim,
hidden_size,
),
)
nn.init.constant_(self.pooled_vec_embedder[-1].weight, 0)
nn.init.constant_(self.pooled_vec_embedder[-1].bias, 0)
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
# t = self.t_embedder(timesteps)
# clip_cls_token = self.cap_embedder(context['vector'])
# clip_spatial_token, dino_spatial_token = context['crossattn'][..., :self.clip_ctx_dim], self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:])
# dino_spatial_token = context['crossattn']
dino_spatial_token = context['img_crossattn']
dino_pooled_vector = context['img_vector']
if self.has_caption:
clip_caption_token = context.get('caption_crossattn')
pooled_vector = torch.cat([dino_pooled_vector, context.get('caption_vector')], -1) # concat dino_vector
else:
clip_caption_token = None
pooled_vector = dino_pooled_vector
t = self.t_embedder(timesteps) + self.pooled_vec_embedder(pooled_vector)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
x = self.x_embedder(x)
freqs_cis = None
if self.enable_rope:
freqs_cis=self.freqs_cis[: x.size(1)]
# add a norm layer here, as in point-e
# x = self.ln_pre(x)
for blk_idx, block in enumerate(self.blocks):
x = block(x, t0, dino_spatial_token=dino_spatial_token, clip_caption_token=clip_caption_token, freqs_cis=freqs_cis)
# add a norm layer here, as in point-e
# x = self.ln_post(x)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = x.to(torch.float32).contiguous()
return x
# xyz-diff
# xyz-cond tex diff
class DiT_I23D_PCD_PixelArt_xyz_cond_kl_diff(DiT_I23D_PCD_PixelArt):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
final_layer_blk=FinalLayer,
create_cap_embedder=True,
use_pe_cond=False,
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
# mixed_prediction, context_dim, roll_out, ImageCondDiTBlockPixelArt,
mixed_prediction, context_dim, pooling_ctx_dim, roll_out, vit_blk,
final_layer_blk)
self.use_pe_cond = use_pe_cond
self.x_embedder = Mlp(in_features=in_channels+3*(1-use_pe_cond),
hidden_features=hidden_size,
out_features=hidden_size,
act_layer=approx_gelu,
drop=0)
if use_pe_cond:
self.xyz_pos_embed = XYZPosEmbed(hidden_size)
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
# t = self.t_embedder(timesteps)
clip_cls_token = self.cap_embedder(context['vector'])
clip_spatial_token, dino_spatial_token = context['crossattn'][..., :self.clip_ctx_dim], self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:])
fps_xyz = context['fps-xyz']
t = self.t_embedder(timesteps) + clip_cls_token # (N, D)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
if self.use_pe_cond:
x = self.x_embedder(x) + self.xyz_pos_embed(fps_xyz) # point-wise addition
else: # use concat to add info
x = torch.cat([fps_xyz, x], dim=-1)
x = self.x_embedder(x)
# add a norm layer here, as in point-e
# x = self.ln_pre(x)
for blk_idx, block in enumerate(self.blocks):
x = block(x, t0, dino_spatial_token=dino_spatial_token, clip_spatial_token=clip_spatial_token)
# add a norm layer here, as in point-e
# x = self.ln_post(x)
x = self.final_layer(x, t) # no loss on the xyz side
x = x.to(torch.float32).contiguous()
return x
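
# Hedged sketch (shapes are assumptions) of the two xyz-conditioning paths used by the class
# above: use_pe_cond=True adds a learned xyz positional embedding in feature space, while
# use_pe_cond=False concatenates the fps_xyz coordinates with the KL latent before x_embedder.
def _xyz_concat_cond_example(kl_latent: torch.Tensor, fps_xyz: torch.Tensor) -> torch.Tensor:
    """Illustrative only: kl_latent (B, N, C) and fps_xyz (B, N, 3) -> (B, N, C + 3) tokens."""
    return torch.cat([fps_xyz, kl_latent], dim=-1)
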
# xyz-cond tex diff, but clay
class DiT_I23D_PCD_PixelArt_noclip_clay_stage2(DiT_I23D_PCD_PixelArt_noclip):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
final_layer_blk=FinalLayer,
create_cap_embedder=True,
use_pe_cond=False,
has_caption=False,
use_clay_ca=False,
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
# mixed_prediction, context_dim, roll_out, ImageCondDiTBlockPixelArt,
mixed_prediction, context_dim, pooling_ctx_dim, roll_out, vit_blk,
final_layer_blk, use_clay_ca=use_clay_ca, has_caption=has_caption)
self.has_caption = False
self.use_pe_cond = use_pe_cond
self.x_embedder = Mlp(in_features=in_channels+3*(1-use_pe_cond),
hidden_features=hidden_size,
out_features=hidden_size,
act_layer=approx_gelu,
drop=0)
if use_pe_cond:
self.xyz_pos_embed = XYZPosEmbed(hidden_size)
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
dino_spatial_token = context['img_crossattn']
dino_pooled_vector = context['img_vector']
if self.has_caption:
clip_caption_token = context.get('caption_crossattn')
pooled_vector = torch.cat([dino_pooled_vector, context.get('caption_vector')], -1) # concat dino_vector
else:
clip_caption_token = None
pooled_vector = dino_pooled_vector
t = self.t_embedder(timesteps) + self.pooled_vec_embedder(pooled_vector)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
fps_xyz = context['fps-xyz']
if self.use_pe_cond:
x = self.x_embedder(x) + self.xyz_pos_embed(fps_xyz) # point-wise addition
else: # use concat to add info
x = torch.cat([fps_xyz, x], dim=-1)
x = self.x_embedder(x)
for blk_idx, block in enumerate(self.blocks):
x = block(x, t0, dino_spatial_token=dino_spatial_token, clip_caption_token=clip_caption_token)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = x.to(torch.float32).contiguous()
return x
class DiT_I23D_PixelArt_MVCond(DiT_I23D_PixelArt):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArt,
final_layer_blk=FinalLayer,
create_cap_embedder=False,
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
# mixed_prediction, context_dim, roll_out, ImageCondDiTBlockPixelArt,
mixed_prediction, context_dim,
pooling_ctx_dim, roll_out, ImageCondDiTBlockPixelArtRMSNorm,
final_layer_blk, create_cap_embedder=create_cap_embedder)
        # Multi-view image conditioning.
        # DINO handles the global pooling here; CLIP takes care of the camera condition via ModLN.
        # MV DINO features arrive through 'concat' and are also globally pooled; InstantMesh likewise uses DINO (but via cross-attention).
        # Since the views are consumed through cross-attention, a variable number of input frames (any context window size) should be supported.
del self.dino_proj
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# st()
# (Pdb) p context.keys()
# dict_keys(['crossattn', 'vector', 'concat'])
# (Pdb) p context['vector'].shape
# torch.Size([2, 768])
# (Pdb) p context['crossattn'].shape
# torch.Size([2, 256, 1024])
# (Pdb) p context['concat'].shape
# torch.Size([2, 4, 256, 768]) # mv dino spatial features
        # ! CLIP spatial tokens are appended to self-attention, hence the projection layer (self.clip_spatial_proj)
        # DINO features go through cross-attention, so no extra projection is needed (the cross-attn block already has KV linear layers)
        clip_cls_token, clip_spatial_token = self.cap_embedder(context['vector']), self.clip_spatial_proj(context['crossattn'])  # no extra norm needed here: QK-norm plus the ViT's ln_post suffice
dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
t = self.t_embedder(timesteps) + clip_cls_token # (N, D)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
# if self.roll_out: # !
x = rearrange(x, 'b (c n) h w->(b n) c h w',
n=3) # downsample with same conv
x = self.x_embedder(x) # (b n) c h/f w/f
x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
x = x + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
for blk_idx, block in enumerate(self.blocks):
# x = block(x, t0, dino_spatial_token=dino_spatial_token, clip_spatial_token=clip_spatial_token) # (N, T, D)
            # ! DINO tokens for cross-attention, CLIP tokens for the appended self-attention path (hence the intentionally swapped kwargs below).
x = block(x, t0, dino_spatial_token=clip_spatial_token, clip_spatial_token=dino_spatial_token) # (N, T, D)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
if self.roll_out: # move n from L to B axis
x = rearrange(x, 'b (n l) c ->(b n) l c', n=3)
x = self.unpatchify(x) # (N, out_channels, H, W)
if self.roll_out: # move n from L to B axis
x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)
x = x.to(torch.float32).contiguous()
return x
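
# Hedged shape sketch for the MV-conditioned forward above, using the values recorded in the
# pdb dump comments: context['vector'] (B, 768), context['crossattn'] (B, 256, 1024),
# context['concat'] (B, V, 256, 768) with V input views. Illustrative only.
def _mv_token_shapes_example(B: int = 2, V: int = 4):
    clip_spatial = torch.randn(B, 256, 1024)                # CLIP spatial tokens (appended to self-attn)
    dino_mv = torch.randn(B, V, 256, 768)                   # per-view DINO spatial tokens
    dino_flat = rearrange(dino_mv, 'b v l c -> b (v l) c')  # flattened for cross-attention
    return clip_spatial.shape, dino_flat.shape
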
class DiT_I23D_PixelArt_MVCond_noClip(DiT_I23D_PixelArt):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArt,
final_layer_blk=FinalLayer,
create_cap_embedder=False,
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
# mixed_prediction, context_dim, roll_out, ImageCondDiTBlockPixelArt,
mixed_prediction, context_dim,
pooling_ctx_dim, roll_out,
ImageCondDiTBlockPixelArtRMSNormNoClip,
final_layer_blk,
create_cap_embedder=create_cap_embedder)
# support multi-view img condition
# DINO handles global pooling here; clip takes care of camera-cond with ModLN
# Input DINO concat also + global pool. InstantMesh adopts DINO (but CA).
# expected: support dynamic numbers of frames? since CA, shall be capable of. Any number of context window size.
del self.dino_proj
del self.clip_spatial_proj, self.cap_embedder # no clip required
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# st()
# (Pdb) p context.keys()
# dict_keys(['crossattn', 'vector', 'concat'])
# (Pdb) p context['vector'].shape
# torch.Size([2, 768])
# (Pdb) p context['crossattn'].shape
# torch.Size([2, 256, 1024])
# (Pdb) p context['concat'].shape
# torch.Size([2, 4, 256, 768]) # mv dino spatial features
# ! clip spatial tokens for append self-attn, thus add a projection layer (self.dino_proj)
# DINO features sent via crossattn, thus no proj required (already KV linear layers in crossattn blk)
# clip_cls_token, clip_spatial_token = self.cap_embedder(context['vector']), self.clip_spatial_proj(context['crossattn']) # no norm here required? QK norm is enough, since self.ln_post(x) in vit
dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
# t = self.t_embedder(timesteps) + clip_cls_token # (N, D)
t = self.t_embedder(timesteps)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
# if self.roll_out: # !
x = rearrange(x, 'b (c n) h w->(b n) c h w',
n=3) # downsample with same conv
x = self.x_embedder(x) # (b n) c h/f w/f
x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
x = x + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
for blk_idx, block in enumerate(self.blocks):
# x = block(x, t0, dino_spatial_token=dino_spatial_token, clip_spatial_token=clip_spatial_token) # (N, T, D)
            # ! DINO tokens for cross-attention (no CLIP tokens in this variant).
x = block(x, t0, dino_spatial_token=dino_spatial_token) # (N, T, D)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
if self.roll_out: # move n from L to B axis
x = rearrange(x, 'b (n l) c ->(b n) l c', n=3)
x = self.unpatchify(x) # (N, out_channels, H, W)
if self.roll_out: # move n from L to B axis
x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)
x = x.to(torch.float32).contiguous()
return x
# pcd-structured latent ddpm
class DiT_pcd_I23D_PixelArt_MVCond(DiT_I23D_PixelArt_MVCond_noClip):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArt,
final_layer_blk=FinalLayer,
create_cap_embedder=False,
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
# mixed_prediction, context_dim, roll_out, ImageCondDiTBlockPixelArt,
mixed_prediction, context_dim,
pooling_ctx_dim,
roll_out, ImageCondDiTBlockPixelArtRMSNorm,
final_layer_blk,
create_cap_embedder=create_cap_embedder)
        # ! First, normalize xyz from [-0.45, 0.45] to [-1, 1] (see the sketch after this class).
        # Then encode xyz with a point Fourier feature + MLP projection, which serves as the PE here.
        # A separate MLP embeds the KL feature; the two are added in feature space,
        # and a single MLP (final_layer) maps the result back to 16 + 3 dims.
self.x_embedder = Mlp(in_features=in_channels,
hidden_features=hidden_size,
out_features=hidden_size,
act_layer=approx_gelu,
drop=0)
del self.pos_embed
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
t = self.t_embedder(timesteps)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
x = self.x_embedder(x)
for blk_idx, block in enumerate(self.blocks):
# x = block(x, t0, dino_spatial_token=dino_spatial_token, clip_spatial_token=clip_spatial_token) # (N, T, D)
# ! DINO tokens for CA, CLIP tokens for append here.
x = block(x, t0, dino_spatial_token=dino_spatial_token)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = x.to(torch.float32).contiguous()
return x
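
# Hedged sketch (illustrative only) of the xyz normalization mentioned in the class above:
# mapping coordinates from the data range [-0.45, 0.45] onto [-1, 1] is a plain rescale.
def _normalize_xyz_example(xyz: torch.Tensor, data_half_range: float = 0.45) -> torch.Tensor:
    """xyz: (..., 3) coordinates assumed to lie in [-data_half_range, data_half_range]."""
    return xyz / data_half_range
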
class DiT_pcd_I23D_PixelArt_MVCond_clay(DiT_PCD_PixelArt):
# fine-tune the mv model from text conditioned model
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArt,
final_layer_blk=FinalLayer,
create_cap_embedder=False, **kwargs
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
# mixed_prediction, context_dim, roll_out, ImageCondDiTBlockPixelArt,
mixed_prediction, context_dim,
# pooling_ctx_dim,
roll_out, vit_blk,
final_layer_blk,)
# create_cap_embedder=create_cap_embedder)
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert context is not None
clip_cls_token = self.cap_embedder(context['caption_vector']) # pooled
t = self.t_embedder(timesteps) + clip_cls_token # (N, D)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
x = self.x_embedder(x)
# ! spatial tokens
dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
# assert context.ndim == 2
# if isinstance(context, dict):
# context = context['caption_crossattn'] # sgm conditioner compat
# loop dit block
for blk_idx, block in enumerate(self.blocks):
x = block(x, t0, clip_caption_token=context['caption_crossattn'],
dino_spatial_token=dino_spatial_token) # (N, T, D)
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
# cast to float32 for better accuracy
x = x.to(torch.float32).contiguous()
return x
# single-img pretrained clay
class DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt(DiT_I23D_PCD_PixelArt_noclip):
# fine-tune the mv model from text conditioned model
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArt,
final_layer_blk=FinalLayer,
create_cap_embedder=False, **kwargs
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
mixed_prediction, context_dim,
pooling_ctx_dim,
roll_out, vit_blk,
final_layer_blk,)
self.has_caption = False
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
# t = self.t_embedder(timesteps)
# clip_cls_token = self.cap_embedder(context['vector'])
# clip_spatial_token, dino_spatial_token = context['crossattn'][..., :self.clip_ctx_dim], self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:])
# dino_spatial_token = context['crossattn']
# st()
dino_spatial_token = context['img_crossattn']
dino_pooled_vector = context['img_vector']
dino_mv_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
if self.has_caption:
clip_caption_token = context.get('caption_crossattn')
pooled_vector = torch.cat([dino_pooled_vector, context.get('caption_vector')], -1) # concat dino_vector
else:
clip_caption_token = None
pooled_vector = dino_pooled_vector
t = self.t_embedder(timesteps) + self.pooled_vec_embedder(pooled_vector)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
x = self.x_embedder(x)
# add a norm layer here, as in point-e
# x = self.ln_pre(x)
for blk_idx, block in enumerate(self.blocks):
x = block(x, t0, dino_spatial_token=dino_spatial_token, dino_mv_spatial_token=dino_mv_spatial_token)
# add a norm layer here, as in point-e
# x = self.ln_post(x)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = x.to(torch.float32).contiguous()
return x
# stage 2
class DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt_stage2(DiT_I23D_PCD_PixelArt_noclip):
# fine-tune the mv model from text conditioned model
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArt,
final_layer_blk=FinalLayer,
create_cap_embedder=False, **kwargs
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
mixed_prediction, context_dim,
pooling_ctx_dim,
roll_out, vit_blk,
final_layer_blk,)
self.has_caption = False
self.use_pe_cond = True
self.x_embedder = Mlp(in_features=in_channels+3*(1-self.use_pe_cond),
hidden_features=hidden_size,
out_features=hidden_size,
act_layer=approx_gelu,
drop=0)
if self.use_pe_cond:
self.xyz_pos_embed = XYZPosEmbed(hidden_size)
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
# t = self.t_embedder(timesteps)
# clip_cls_token = self.cap_embedder(context['vector'])
# clip_spatial_token, dino_spatial_token = context['crossattn'][..., :self.clip_ctx_dim], self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:])
# dino_spatial_token = context['crossattn']
# st()
dino_spatial_token = context['img_crossattn']
dino_pooled_vector = context['img_vector']
dino_mv_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
if self.has_caption:
clip_caption_token = context.get('caption_crossattn')
pooled_vector = torch.cat([dino_pooled_vector, context.get('caption_vector')], -1) # concat dino_vector
else:
clip_caption_token = None
pooled_vector = dino_pooled_vector
t = self.t_embedder(timesteps) + self.pooled_vec_embedder(pooled_vector)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
# x = self.x_embedder(x)
fps_xyz = context['fps-xyz']
if self.use_pe_cond:
x = self.x_embedder(x) + self.xyz_pos_embed(fps_xyz) # point-wise addition
else: # use concat to add info
x = torch.cat([fps_xyz, x], dim=-1)
x = self.x_embedder(x)
# add a norm layer here, as in point-e
# x = self.ln_pre(x)
for blk_idx, block in enumerate(self.blocks):
x = block(x, t0, dino_spatial_token=dino_spatial_token, dino_mv_spatial_token=dino_mv_spatial_token)
# add a norm layer here, as in point-e
# x = self.ln_post(x)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = x.to(torch.float32).contiguous()
return x
class DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt_noi23d(DiT_I23D_PCD_PixelArt_noclip):
# fine-tune the mv model from text conditioned model
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
mixing_logit_init=-3,
mixed_prediction=True,
context_dim=False,
pooling_ctx_dim=768,
roll_out=False,
vit_blk=ImageCondDiTBlockPixelArt,
final_layer_blk=FinalLayer,
create_cap_embedder=False, **kwargs
):
super().__init__(input_size, patch_size, in_channels, hidden_size,
depth, num_heads, mlp_ratio, class_dropout_prob,
num_classes, learn_sigma, mixing_logit_init,
mixed_prediction, context_dim,
pooling_ctx_dim,
roll_out, vit_blk,
final_layer_blk,)
self.has_caption = False
del self.pooled_vec_embedder
def forward(self,
x,
timesteps=None,
context=None,
y=None,
get_attr='',
**kwargs):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
# t = timesteps
assert isinstance(context, dict)
# dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
# t = self.t_embedder(timesteps)
# clip_cls_token = self.cap_embedder(context['vector'])
# clip_spatial_token, dino_spatial_token = context['crossattn'][..., :self.clip_ctx_dim], self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:])
# dino_spatial_token = context['crossattn']
# st()
# dino_spatial_token = context['img_crossattn']
# dino_pooled_vector = context['img_vector']
dino_mv_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c') # flatten MV dino features.
if self.has_caption:
clip_caption_token = context.get('caption_crossattn')
pooled_vector = torch.cat([dino_pooled_vector, context.get('caption_vector')], -1) # concat dino_vector
else:
clip_caption_token = None
# pooled_vector = dino_pooled_vector
pooled_vector = None
# t = self.t_embedder(timesteps) + self.pooled_vec_embedder(pooled_vector)
t = self.t_embedder(timesteps)
t0 = self.adaLN_modulation(t) # single-adaLN, B 6144
x = self.x_embedder(x)
# add a norm layer here, as in point-e
# x = self.ln_pre(x)
for blk_idx, block in enumerate(self.blocks):
# x = block(x, t0, dino_spatial_token=dino_spatial_token, dino_mv_spatial_token=dino_mv_spatial_token)
x = block(x, t0, dino_mv_spatial_token=dino_mv_spatial_token)
# add a norm layer here, as in point-e
# x = self.ln_post(x)
# todo later
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = x.to(torch.float32).contiguous()
return x
#################################################################################
# DiT_I23D Configs #
#################################################################################
def DiT_XL_2(**kwargs):
return DiT_I23D(depth=28,
hidden_size=1152,
patch_size=2,
num_heads=16,
**kwargs)
def DiT_L_2(**kwargs):
return DiT_I23D(depth=24,
hidden_size=1024,
patch_size=2,
num_heads=16,
**kwargs)
def DiT_B_2(**kwargs):
return DiT_I23D(depth=12,
hidden_size=768,
patch_size=2,
num_heads=12,
**kwargs)
def DiT_B_1(**kwargs):
return DiT_I23D(depth=12,
hidden_size=768,
patch_size=1,
num_heads=12,
**kwargs)
def DiT_L_Pixelart_2(**kwargs):
return DiT_I23D_PixelArt(depth=24,
hidden_size=1024,
patch_size=2,
num_heads=16,
**kwargs)
def DiT_B_Pixelart_2(**kwargs):
return DiT_I23D_PixelArt(depth=12,
hidden_size=768,
patch_size=2,
num_heads=12,
**kwargs)
def DiT_L_Pixelart_MV_2(**kwargs):
return DiT_I23D_PixelArt_MVCond(depth=24,
hidden_size=1024,
patch_size=2,
num_heads=16,
**kwargs)
def DiT_L_Pixelart_MV_2_noclip(**kwargs):
return DiT_I23D_PixelArt_MVCond_noClip(depth=24,
hidden_size=1024,
patch_size=2,
num_heads=16,
**kwargs)
def DiT_XL_Pixelart_MV_2(**kwargs):
return DiT_I23D_PixelArt_MVCond(depth=28,
hidden_size=1152,
patch_size=2,
num_heads=16,
**kwargs)
def DiT_B_Pixelart_MV_2(**kwargs):
return DiT_I23D_PixelArt_MVCond(depth=12,
hidden_size=768,
patch_size=2,
num_heads=12,
**kwargs)
# pcd latent
def DiT_L_Pixelart_MV_pcd(**kwargs):
return DiT_pcd_I23D_PixelArt_MVCond(depth=24,
hidden_size=1024,
patch_size=1, # no spatial compression here
num_heads=16,
**kwargs)
# raw gs i23d
def DiT_L_Pixelart_pcd(**kwargs):
return DiT_I23D_PCD_PixelArt(depth=24,
# return DiT_I23D_PCD_PixelArt_noclip(depth=24,
hidden_size=1024,
patch_size=1, # no spatial compression here
num_heads=16,
**kwargs)
def DiT_L_Pixelart_clay_pcd(**kwargs):
return DiT_I23D_PCD_PixelArt_noclip(depth=24,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
use_clay_ca=True,
hidden_size=1024,
patch_size=1, # no spatial compression here
num_heads=16,
enable_rope=False,
**kwargs)
def DiT_XL_Pixelart_clay_pcd(**kwargs):
return DiT_I23D_PCD_PixelArt_noclip(depth=28,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
use_clay_ca=True,
hidden_size=1152,
patch_size=1, # no spatial compression here
num_heads=16,
enable_rope=False,
**kwargs)
def DiT_B_Pixelart_clay_pcd(**kwargs):
return DiT_I23D_PCD_PixelArt_noclip(depth=12,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
use_clay_ca=True,
hidden_size=768,
patch_size=1, # no spatial compression here
num_heads=12,
**kwargs)
def DiT_L_Pixelart_clay_pcd_stage2(**kwargs):
return DiT_I23D_PCD_PixelArt_noclip_clay_stage2(depth=24,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
use_clay_ca=True,
hidden_size=1024,
patch_size=1, # no spatial compression here
num_heads=16,
use_pe_cond=True,
**kwargs)
def DiT_B_Pixelart_clay_pcd_stage2(**kwargs):
return DiT_I23D_PCD_PixelArt_noclip_clay_stage2(depth=12,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
use_clay_ca=True,
hidden_size=768,
patch_size=1, # no spatial compression here
num_heads=12,
use_pe_cond=True,
**kwargs)
def DiT_L_Pixelart_clay_tandi_pcd(**kwargs):
return DiT_I23D_PCD_PixelArt_noclip(depth=24,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayText,
use_clay_ca=True,
hidden_size=1024,
patch_size=1, # no spatial compression here
num_heads=16,
has_caption=True,
**kwargs)
def DiT_B_Pixelart_clay_tandi_pcd(**kwargs):
return DiT_I23D_PCD_PixelArt_noclip(depth=12,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayText,
use_clay_ca=True,
hidden_size=768,
patch_size=1, # no spatial compression here
num_heads=12,
has_caption=True,
**kwargs)
def DiT_B_Pixelart_pcd(**kwargs):
return DiT_I23D_PCD_PixelArt(depth=12,
hidden_size=768,
patch_size=1, # no spatial compression here
num_heads=12,
**kwargs)
def DiT_B_Pixelart_pcd_cond_diff(**kwargs):
return DiT_I23D_PCD_PixelArt_xyz_cond_kl_diff(depth=12,
hidden_size=768,
patch_size=1, # no spatial compression here
num_heads=12,
**kwargs)
def DiT_B_Pixelart_pcd_cond_diff_pe(**kwargs):
return DiT_I23D_PCD_PixelArt_xyz_cond_kl_diff(depth=12,
hidden_size=768,
patch_size=1, # no spatial compression here
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
num_heads=12,
use_pe_cond=True,
**kwargs)
def DiT_L_Pixelart_pcd_cond_diff_pe(**kwargs):
return DiT_I23D_PCD_PixelArt_xyz_cond_kl_diff(depth=24,
hidden_size=1024,
patch_size=1, # no spatial compression here
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
num_heads=16,
use_pe_cond=True,
**kwargs)
# mv version
def DiT_L_Pixelart_clay_mv_pcd(**kwargs):
return DiT_pcd_I23D_PixelArt_MVCond_clay(depth=24,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayText,
use_clay_ca=True,
hidden_size=1024,
patch_size=1, # no spatial compression here
num_heads=16,
**kwargs)
def DiT_L_Pixelart_clay_mv_i23dpt_pcd(**kwargs):
return DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt(depth=24,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayMV,
use_clay_ca=True,
hidden_size=1024,
patch_size=1, # no spatial compression here
num_heads=16,
**kwargs)
def DiT_L_Pixelart_clay_mv_i23dpt_pcd_noi23d(**kwargs):
return DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt_noi23d(depth=24,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayMV_noi23d,
use_clay_ca=True,
hidden_size=1024,
patch_size=1, # no spatial compression here
num_heads=16,
**kwargs)
def DiT_L_Pixelart_clay_mv_i23dpt_pcd_stage2(**kwargs):
return DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt_stage2(depth=24,
vit_blk=ImageCondDiTBlockPixelArtRMSNormClayMV,
use_clay_ca=True,
hidden_size=1024,
patch_size=1, # no spatial compression here
num_heads=16,
**kwargs)
DiT_models = {
'DiT-XL/2': DiT_XL_2,
'DiT-L/2': DiT_L_2,
'DiT-B/2': DiT_B_2,
'DiT-B/1': DiT_B_1,
'DiT-PixArt-L/2': DiT_L_Pixelart_2,
'DiT-PixArt-MV-XL/2': DiT_XL_Pixelart_MV_2,
# 'DiT-PixArt-MV-L/2': DiT_L_Pixelart_MV_2,
'DiT-PixArt-MV-L/2': DiT_L_Pixelart_MV_2_noclip,
'DiT-PixArt-MV-PCD-L': DiT_L_Pixelart_MV_pcd,
# raw xyz cond
'DiT-PixArt-PCD-L': DiT_L_Pixelart_pcd,
'DiT-PixArt-PCD-CLAY-XL': DiT_XL_Pixelart_clay_pcd,
'DiT-PixArt-PCD-CLAY-L': DiT_L_Pixelart_clay_pcd,
'DiT-PixArt-PCD-CLAY-B': DiT_B_Pixelart_clay_pcd,
'DiT-PixArt-PCD-CLAY-stage2-B': DiT_B_Pixelart_clay_pcd_stage2,
'DiT-PixArt-PCD-CLAY-stage2-L': DiT_L_Pixelart_clay_pcd_stage2,
'DiT-PixArt-PCD-CLAY-TandI-L': DiT_L_Pixelart_clay_tandi_pcd,
'DiT-PixArt-PCD-CLAY-TandI-B': DiT_B_Pixelart_clay_tandi_pcd,
'DiT-PixArt-PCD-B': DiT_B_Pixelart_pcd,
# xyz-conditioned KL feature diffusion
'DiT-PixArt-PCD-cond-diff-B': DiT_B_Pixelart_pcd_cond_diff,
'DiT-PixArt-PCD-cond-diff-pe-B': DiT_B_Pixelart_pcd_cond_diff_pe,
'DiT-PixArt-PCD-cond-diff-pe-L': DiT_L_Pixelart_pcd_cond_diff_pe,
'DiT-PixArt-MV-B/2': DiT_B_Pixelart_MV_2,
'DiT-PixArt-B/2': DiT_B_Pixelart_2,
# ! mv version following clay
'DiT-PixArt-PCD-MV-L': DiT_L_Pixelart_clay_mv_pcd,
'DiT-PixArt-PCD-MV-I23Dpt-L': DiT_L_Pixelart_clay_mv_i23dpt_pcd,
'DiT-PixArt-PCD-MV-I23Dpt-L-noI23D': DiT_L_Pixelart_clay_mv_i23dpt_pcd_noi23d,
'DiT-PixArt-PCD-MV-I23Dpt-L-stage2': DiT_L_Pixelart_clay_mv_i23dpt_pcd_stage2,
}
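

# Hedged usage sketch (caller-side, not part of the original file): models are typically
# built by looking up a config key in DiT_models and passing the dataset-specific kwargs
# (in_channels, context_dim, ...) expected by the chosen class.
if __name__ == "__main__":
    for name in sorted(DiT_models):
        print(name)
    # Example (hypothetical kwargs, shown for illustration only):
    # model = DiT_models['DiT-PixArt-PCD-CLAY-L'](in_channels=16, context_dim=1024, pooling_ctx_dim=1024)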