from transformers import PretrainedConfig


class STDiTConfig(PretrainedConfig):
    """Configuration class for the STDiT (Spatial-Temporal Diffusion Transformer) model.

    Arguments mirror the model constructor:
        input_size: latent size as (num_frames, height, width).
        in_channels: number of input latent channels.
        patch_size: patch size as (temporal, height, width).
        hidden_size: transformer hidden dimension.
        depth: number of transformer blocks.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden dimension ratio relative to hidden_size.
        class_dropout_prob: caption dropout probability used for classifier-free guidance.
        pred_sigma: whether the model also predicts the noise variance.
        drop_path: stochastic depth (DropPath) rate.
        no_temporal_pos_emb: disable the temporal positional embedding.
        caption_channels: dimension of the text-encoder caption features.
        model_max_length: maximum caption token length.
        space_scale / time_scale: scaling factors for the spatial / temporal positional embeddings.
        freeze: which part of the model to freeze, if any.
        enable_flash_attn / enable_layernorm_kernel / enable_sequence_parallelism:
            optional kernel and parallelism optimizations.
    """

    model_type = "stdit"

    def __init__(
        self,
        input_size=(1, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path=0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        space_scale=1.0,
        time_scale=1.0,
        freeze=None,
        enable_flash_attn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        **kwargs,
    ):
        self.input_size = input_size
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.class_dropout_prob = class_dropout_prob
        self.pred_sigma = pred_sigma
        self.drop_path = drop_path
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.caption_channels = caption_channels
        self.model_max_length = model_max_length
        self.space_scale = space_scale
        self.time_scale = time_scale
        self.freeze = freeze
        self.enable_flash_attn = enable_flash_attn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.enable_sequence_parallelism = enable_sequence_parallelism
        super().__init__(**kwargs)
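

# Minimal usage sketch: build a config and round-trip it through the standard
# PretrainedConfig serialization helpers. The overridden values and the
# "stdit_config" directory name are illustrative only.
if __name__ == "__main__":
    config = STDiTConfig(depth=28, hidden_size=1152, num_heads=16)
    config.save_pretrained("stdit_config")  # writes stdit_config/config.json
    reloaded = STDiTConfig.from_pretrained("stdit_config")
    assert reloaded.hidden_size == config.hidden_size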