# Uploaded via huggingface_hub ("Upload folder using huggingface_hub", commit 87e21d1, verified)
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class BaseConfig:
    """Dict-like convenience mixin shared by all config dataclasses.

    Supplies ``get``/``pop`` with defaults (mirroring ``dict``) and a
    pretty-printed JSON ``str()`` representation of the dataclass fields.
    """

    def get(self, attribute_name, default=None):
        """Return the value of *attribute_name*, or *default* if absent."""
        return getattr(self, attribute_name, default)

    def pop(self, attribute_name, default=None):
        """Remove *attribute_name* and return its value; *default* if absent."""
        # EAFP: fetch first, delete only once the fetch has succeeded.
        try:
            value = getattr(self, attribute_name)
        except AttributeError:
            return default
        delattr(self, attribute_name)
        return value

    def __str__(self):
        """Serialize all dataclass fields as indented JSON."""
        return json.dumps(asdict(self), indent=4)
@dataclass
class DataConfig(BaseConfig):
    """Dataset and dataloading settings."""

    data_dir: List[Optional[str]] = field(default_factory=list)  # one or more dataset root directories
    caption_proportion: Dict[str, int] = field(default_factory=lambda: {"prompt": 1})  # relative weight per caption source — confirm with loader
    external_caption_suffixes: List[str] = field(default_factory=list)  # file suffixes of external caption files
    external_clipscore_suffixes: List[str] = field(default_factory=list)  # file suffixes of external CLIP-score files
    clip_thr_temperature: float = 1.0  # temperature for CLIP-score-weighted caption selection — presumably; confirm
    clip_thr: float = 0.0  # minimum CLIP-score threshold (0.0 looks like "no filtering" — confirm)
    sort_dataset: bool = False
    load_text_feat: bool = False  # load precomputed text features instead of raw captions
    load_vae_feat: bool = False  # load precomputed VAE latents instead of images
    transform: str = "default_train"  # name of the image transform pipeline
    type: str = "SanaWebDatasetMS"  # dataset class registry key
    image_size: int = 512  # base training resolution
    hq_only: bool = False  # restrict to a high-quality subset — presumably; confirm
    valid_num: int = 0
    data: Any = None  # free-form payload filled in elsewhere
    extra: Any = None  # free-form overflow for unrecognized keys
@dataclass
class ModelConfig(BaseConfig):
    """Diffusion-transformer architecture and checkpoint settings."""

    model: str = "SanaMS_600M_P1_D28"  # model architecture registry key
    image_size: int = 512  # resolution the model is built for
    mixed_precision: str = "fp16"  # ['fp16', 'fp32', 'bf16']
    fp32_attention: bool = True  # keep attention in fp32 under mixed precision — presumably; confirm
    load_from: Optional[str] = None  # checkpoint path for weight initialization
    resume_from: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {
            "checkpoint": None,  # checkpoint path to resume training from
            "load_ema": False,  # also restore EMA weights
            "resume_lr_scheduler": True,
            "resume_optimizer": True,
        }
    )
    aspect_ratio_type: str = "ASPECT_RATIO_1024"  # named aspect-ratio bucket set
    multi_scale: bool = True  # multi-resolution (aspect-ratio bucketed) training
    pe_interpolation: float = 1.0  # positional-embedding interpolation factor
    micro_condition: bool = False
    attn_type: str = "linear"  # attention variant
    autocast_linear_attn: bool = False
    ffn_type: str = "glumbconv"  # feed-forward block variant
    mlp_acts: List[Optional[str]] = field(default_factory=lambda: ["silu", "silu", None])  # activation per MLP stage; None = identity — confirm
    mlp_ratio: float = 2.5  # MLP hidden-dim expansion ratio
    use_pe: bool = False  # enable positional embeddings
    qk_norm: bool = False  # normalize Q/K in attention
    class_dropout_prob: float = 0.0  # condition dropout probability — presumably for CFG training; confirm
    linear_head_dim: int = 32
    cross_norm: bool = False
    cfg_scale: int = 4  # classifier-free guidance scale
    guidance_type: str = "classifier-free"
    pag_applied_layers: List[int] = field(default_factory=lambda: [14])  # layers where PAG applies — confirm semantics with sampler
    extra: Any = None  # free-form overflow for unrecognized keys
@dataclass
class AEConfig(BaseConfig):
    """Autoencoder (latent space) settings."""

    vae_type: str = "dc-ae"  # autoencoder family
    vae_pretrained: str = "mit-han-lab/dc-ae-f32c32-sana-1.0"  # hub id or path of the pretrained AE
    scale_factor: float = 0.41407  # latent scaling factor — confirm against the AE pipeline
    vae_latent_dim: int = 32  # latent channel count
    vae_downsample_rate: int = 32  # spatial downsample factor (matches the f32 in the pretrained name)
    sample_posterior: bool = True  # sample from the posterior rather than using its mean — presumably; confirm
    extra: Any = None  # free-form overflow for unrecognized keys
@dataclass
class TextEncoderConfig(BaseConfig):
    """Text-encoder (prompt embedding) settings."""

    text_encoder_name: str = "gemma-2-2b-it"  # identifier of the text-encoder model
    caption_channels: int = 2304  # embedding dimension produced by the encoder
    y_norm: bool = True  # normalize text embeddings — presumably; confirm with consumer
    y_norm_scale_factor: float = 1.0  # scale applied to (normalized) embeddings
    model_max_length: int = 300  # maximum prompt token length
    # default_factory=list replaces the equivalent but non-idiomatic `lambda: []`.
    chi_prompt: List[Optional[str]] = field(default_factory=list)  # optional instruction/prefix prompt lines — confirm usage
    extra: Any = None  # free-form overflow for unrecognized keys
@dataclass
class SchedulerConfig(BaseConfig):
    """Diffusion noise-schedule and sampler settings."""

    train_sampling_steps: int = 1000  # number of timesteps in the training schedule
    predict_v: bool = True  # v-prediction parameterization — presumably; confirm with trainer
    noise_schedule: str = "linear_flow"
    pred_sigma: bool = False
    learn_sigma: bool = True
    vis_sampler: str = "flow_dpm-solver"  # sampler used for visualization
    flow_shift: float = 1.0
    # Logit-normal timestep sampling parameters (used when weighting_scheme == "logit_normal").
    weighting_scheme: Optional[str] = "logit_normal"
    logit_mean: float = 0.0  # mean of the logit-normal timestep distribution
    logit_std: float = 1.0  # std of the logit-normal timestep distribution
    extra: Any = None  # free-form overflow for unrecognized keys
@dataclass
class TrainingConfig(BaseConfig):
    """Optimizer, schedule, logging, evaluation, and checkpointing settings."""

    num_workers: int = 4  # dataloader worker processes
    seed: int = 43  # global RNG seed
    train_batch_size: int = 32  # batch size — per-device or global not determinable here; confirm
    num_epochs: int = 100
    gradient_accumulation_steps: int = 1
    grad_checkpointing: bool = False  # activation checkpointing to trade compute for memory
    gradient_clip: float = 1.0  # gradient clipping value — presumably a max norm; confirm
    gc_step: int = 1
    optimizer: Dict[str, Any] = field(
        default_factory=lambda: {"eps": 1.0e-10, "lr": 0.0001, "type": "AdamW", "weight_decay": 0.03}
    )  # optimizer kwargs plus a "type" registry key
    lr_schedule: str = "constant"
    lr_schedule_args: Dict[str, int] = field(default_factory=lambda: {"num_warmup_steps": 500})
    auto_lr: Dict[str, str] = field(default_factory=lambda: {"rule": "sqrt"})  # automatic LR scaling rule
    ema_rate: float = 0.9999  # EMA decay for model weights
    eval_batch_size: int = 16
    use_fsdp: bool = False  # fully sharded data parallel
    use_flash_attn: bool = False
    eval_sampling_steps: int = 250  # interval or sampler step count — ambiguous here; confirm with trainer
    lora_rank: int = 4
    log_interval: int = 50  # steps between log outputs
    mask_type: str = "null"
    mask_loss_coef: float = 0.0
    load_mask_index: bool = False
    snr_loss: bool = False
    real_prompt_ratio: float = 1.0
    save_image_epochs: int = 1
    save_model_epochs: int = 1
    save_model_steps: int = 1000000
    visualize: bool = False
    null_embed_root: str = "output/pretrained_models/"  # directory holding the null (unconditional) prompt embedding
    valid_prompt_embed_root: str = "output/tmp_embed/"  # directory for cached validation prompt embeddings
    validation_prompts: List[str] = field(
        default_factory=lambda: [
            "dog",
            "portrait photo of a girl, photograph, highly detailed face, depth of field",
            "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
        ]
    )  # prompts rendered during validation/visualization
    local_save_vis: bool = False  # also save visualizations locally
    deterministic_validation: bool = True
    online_metric: bool = False
    eval_metric_step: int = 5000
    online_metric_dir: str = "metric_helper"
    work_dir: str = "/cache/exps/"  # scratch directory for training artifacts
    skip_step: int = 0  # number of steps to skip — presumably on resume; confirm
    loss_type: str = "huber"
    huber_c: float = 0.001  # Huber-loss transition constant
    num_ddim_timesteps: int = 50
    w_max: float = 15.0  # guidance-weight upper bound — presumably; confirm
    w_min: float = 3.0  # guidance-weight lower bound — presumably; confirm
    ema_decay: float = 0.95
    debug_nan: bool = False  # enable extra NaN debugging checks
    extra: Any = None  # free-form overflow for unrecognized keys
@dataclass
class SanaConfig(BaseConfig):
    """Top-level experiment config bundling all sub-configs.

    The six sub-config fields are required (no defaults); the remaining
    fields are run-level bookkeeping.
    """

    data: DataConfig
    model: ModelConfig
    vae: AEConfig
    text_encoder: TextEncoderConfig
    scheduler: SchedulerConfig
    train: TrainingConfig
    work_dir: str = "output/"  # output directory for this run
    resume_from: Optional[str] = None  # checkpoint to resume from — NOTE: ModelConfig has a dict-valued resume_from; confirm which wins
    load_from: Optional[str] = None  # checkpoint for weight initialization
    debug: bool = False
    caching: bool = False
    report_to: str = "wandb"  # experiment tracker backend
    tracker_project_name: str = "t2i-evit-baseline"
    name: str = "baseline"  # run name
    loss_report_name: str = "loss"  # metric key under which loss is logged