|
import torch.nn as nn |
|
|
|
from .clip import FrozenCLIPEmbedder |
|
from .switti import Switti |
|
from .vqvae import VQVAE |
|
from .pipeline import SwittiPipeline |
|
|
|
|
|
def build_models(
    device,
    patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    V=4096,
    Cvae=32,
    ch=160,
    share_quant_resi=4,
    depth=16,
    rope=True,
    rope_theta=10000,
    rope_size=128,
    use_swiglu_ffn=True,
    use_ar=False,
    use_crop_cond=True,
    attn_l2_norm=True,
    init_adaln=0.5,
    init_adaln_gamma=1e-5,
    init_head=0.02,
    init_std=-1,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dpr=0,
    norm_eps=1e-6,
    text_encoder_path="openai/clip-vit-large-patch14",
    text_encoder_2_path="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
) -> tuple[VQVAE, Switti, SwittiPipeline]:
    """Build the VQVAE, the Switti transformer and a pipeline wrapping both.

    Args:
        device: Device the VQVAE and Switti modules are moved to; it is also
            handed to ``SwittiPipeline``.
        patch_nums: Passed to ``VQVAE`` as ``v_patch_nums`` and to ``Switti``
            as ``patch_nums``.
        V: Passed to ``VQVAE`` as ``vocab_size``.
        Cvae: Passed to ``VQVAE`` as ``z_channels``.
        ch: Passed to ``VQVAE`` as ``ch``.
        share_quant_resi: Passed through to ``VQVAE``.
        depth: Passed to ``Switti`` as ``depth``. The number of attention
            heads is set to ``depth`` and the embedding width to
            ``depth * 64``.
        rope, rope_theta, rope_size, use_swiglu_ffn, use_ar, use_crop_cond,
        attn_l2_norm, drop_rate, attn_drop_rate, norm_eps: Passed through to
            ``Switti`` unchanged.
        init_adaln, init_adaln_gamma, init_head, init_std: Passed through to
            ``Switti.init_weights``.
        dpr: Base drop-path rate. When positive it is rescaled by
            ``depth / 24`` before being passed to ``Switti`` as
            ``drop_path_rate``.
        text_encoder_path: Argument for the first ``FrozenCLIPEmbedder``.
        text_encoder_2_path: Argument for the second ``FrozenCLIPEmbedder``.

    Returns:
        A 3-tuple ``(vae_local, switti_wo_ddp, pipe)``: the VQVAE, the Switti
        model, and the ``SwittiPipeline`` built from them and the two text
        encoders.
    """
    heads = depth
    width = depth * 64
    if dpr > 0:
        dpr = dpr * depth / 24

    # Replace ``reset_parameters`` with a no-op on these layer classes so that
    # constructing the models below does not run their default initialisers.
    # NOTE: this patches the torch.nn classes themselves, so it stays in
    # effect for every such layer created afterwards in this process.
    for clz in (
        nn.Linear,
        nn.LayerNorm,
        nn.BatchNorm2d,
        nn.SyncBatchNorm,
        nn.Conv1d,
        nn.Conv2d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
    ):
        setattr(clz, "reset_parameters", lambda self: None)

    # NOTE(review): no weights are loaded or initialised for the VQVAE here;
    # presumably the caller loads a checkpoint afterwards -- confirm.
    vae_local = VQVAE(
        vocab_size=V,
        z_channels=Cvae,
        ch=ch,
        test_mode=True,
        share_quant_resi=share_quant_resi,
        v_patch_nums=patch_nums,
    ).to(device)

    switti_wo_ddp = Switti(
        depth=depth,
        embed_dim=width,
        num_heads=heads,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
        drop_path_rate=dpr,
        norm_eps=norm_eps,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        rope=rope,
        rope_theta=rope_theta,
        rope_size=rope_size,
        use_swiglu_ffn=use_swiglu_ffn,
        use_ar=use_ar,
        use_crop_cond=use_crop_cond,
    ).to(device)

    # Default initialisers were disabled above, so Switti is initialised
    # explicitly through its own routine.
    switti_wo_ddp.init_weights(
        init_adaln=init_adaln,
        init_adaln_gamma=init_adaln_gamma,
        init_head=init_head,
        init_std=init_std,
    )

    # The text encoders are not moved to ``device`` here; the device is only
    # handed to the pipeline.
    text_encoder = FrozenCLIPEmbedder(text_encoder_path)
    text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
    pipe = SwittiPipeline(
        switti_wo_ddp, vae_local, text_encoder, text_encoder_2, device
    )

    return vae_local, switti_wo_ddp, pipe
|
|