# configuration_stdit.py: STDiT model configuration (OpenSora-STDiT-v1-HQ-16x512x512)
from transformers import PretrainedConfig


class STDiTConfig(PretrainedConfig):
    """Configuration class for the STDiT (Spatial-Temporal Diffusion Transformer) model."""

    model_type = "stdit"

    def __init__(
        self,
        input_size=(1, 32, 32),  # latent input size as (num_frames, height, width)
        in_channels=4,  # number of input latent channels
        patch_size=(1, 2, 2),  # patch size along the (temporal, height, width) axes
        hidden_size=1152,  # transformer embedding dimension
        depth=28,  # number of transformer blocks
        num_heads=16,  # number of attention heads
        mlp_ratio=4.0,  # MLP hidden size as a multiple of hidden_size
        class_dropout_prob=0.1,  # caption dropout rate for classifier-free guidance
        pred_sigma=True,  # predict the diffusion variance (sigma) as well as the noise
        drop_path=0.0,  # stochastic depth rate
        no_temporal_pos_emb=False,  # disable the temporal positional embedding
        caption_channels=4096,  # dimension of the text-encoder caption embeddings
        model_max_length=120,  # maximum caption length in tokens
        space_scale=1.0,  # scale factor for the spatial positional embedding
        time_scale=1.0,  # scale factor for the temporal positional embedding
        freeze=None,  # optional strategy for freezing parts of the model
        enable_flash_attn=False,  # use FlashAttention kernels if available
        enable_layernorm_kernel=False,  # use a fused LayerNorm kernel if available
        enable_sequence_parallelism=False,  # shard the sequence dimension across devices
        **kwargs,
    ):
self.input_size = input_size
self.in_channels = in_channels
self.patch_size = patch_size
self.hidden_size = hidden_size
self.depth = depth
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.class_dropout_prob = class_dropout_prob
self.pred_sigma = pred_sigma
self.drop_path = drop_path
self.no_temporal_pos_emb = no_temporal_pos_emb
self.caption_channels = caption_channels
self.model_max_length = model_max_length
self.space_scale = space_scale
self.time_scale = time_scale
self.freeze = freeze
self.enable_flash_attn = enable_flash_attn
self.enable_layernorm_kernel = enable_layernorm_kernel
self.enable_sequence_parallelism = enable_sequence_parallelism
super().__init__(**kwargs)
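

# Minimal usage sketch (not part of the upstream file): assuming only that
# `transformers` is installed, this shows the standard PretrainedConfig
# round-trip of constructing, saving, and reloading the config. The directory
# name "stdit_config" is illustrative.
if __name__ == "__main__":
    from transformers import AutoConfig

    config = STDiTConfig(depth=28, enable_flash_attn=True)
    config.save_pretrained("stdit_config")  # writes stdit_config/config.json
    reloaded = STDiTConfig.from_pretrained("stdit_config")
    assert reloaded.depth == config.depth

    # Registering the custom model_type lets AutoConfig resolve it as well.
    AutoConfig.register("stdit", STDiTConfig)
    auto_loaded = AutoConfig.from_pretrained("stdit_config")
    assert auto_loaded.model_type == "stdit"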