|
from typing import Literal, Union, Optional, Tuple, List
from enum import Enum

import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from diffusers import (
    UNet2DConditionModel,
    SchedulerMixin,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoencoderKL,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    convert_ldm_unet_checkpoint,
)
from safetensors.torch import load_file
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    LMSDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    UniPCMultistepScheduler,
)
from omegaconf import OmegaConf
|
|
|
|
|
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120

# Stable Diffusion v1.x UNet hyperparameters (LDM convention)
UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 64
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8

# VAE hyperparameters shared by v1.x and v2.x
VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2

# Stable Diffusion v2.x overrides
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024

TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"

AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "euler", "uniPC"]

SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]

DIFFUSERS_CACHE_DIR = None  # None = use the default Hugging Face cache directory
|
|
|
|
|
def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
    # Key prefixes used by older LDM checkpoints, mapped to the layout
    # expected by transformers' CLIPTextModel.
    TEXT_ENCODER_KEY_REPLACEMENTS = [
        (
            "cond_stage_model.transformer.embeddings.",
            "cond_stage_model.transformer.text_model.embeddings.",
        ),
        (
            "cond_stage_model.transformer.encoder.",
            "cond_stage_model.transformer.text_model.encoder.",
        ),
        (
            "cond_stage_model.transformer.final_layer_norm.",
            "cond_stage_model.transformer.text_model.final_layer_norm.",
        ),
    ]

    if ckpt_path.endswith(".safetensors"):
        checkpoint = None
        state_dict = load_file(ckpt_path)
    else:
        checkpoint = torch.load(ckpt_path, map_location=device)
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
            checkpoint = None

    # Rename legacy text-encoder keys in place.
    key_reps = []
    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
        for key in state_dict.keys():
            if key.startswith(rep_from):
                new_key = rep_to + key[len(rep_from) :]
                key_reps.append((key, new_key))

    for key, new_key in key_reps:
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

    return checkpoint, state_dict
|
|
|
|
|
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
    """
    Creates a diffusers UNet2DConditionModel config from the hard-coded
    hyperparameters of the original LDM model.
    """
    block_out_channels = [
        UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
    ]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = (
            "CrossAttnDownBlock2D"
            if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
            else "DownBlock2D"
        )
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = (
            "CrossAttnUpBlock2D"
            if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
            else "UpBlock2D"
        )
        up_block_types.append(block_type)
        resolution //= 2

    config = dict(
        sample_size=UNET_PARAMS_IMAGE_SIZE,
        in_channels=UNET_PARAMS_IN_CHANNELS,
        out_channels=UNET_PARAMS_OUT_CHANNELS,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
        if not v2
        else V2_UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=UNET_PARAMS_NUM_HEADS
        if not v2
        else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
    )
    if v2 and use_linear_projection_in_v2:
        config["use_linear_projection"] = True

    return config
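# Illustrative sketch (not part of the original module): the v1 config produced
# above can be passed straight to the UNet constructor; the names below are
# hypothetical.
#
#   v1_config = create_unet_diffusers_config(v2=False)
#   # -> block_out_channels = (320, 640, 1280, 1280), cross_attention_dim = 768,
#   #    down_block_types = ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
#   unet = UNet2DConditionModel(**v1_config)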
|
|
|
|
|
def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
    if v2:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V2_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            # the v2 text encoder ships with 23 hidden layers (penultimate-layer
            # output); clip_skip trims additional layers from the top
            num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
    else:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V1_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    return tokenizer, text_encoder, unet, vae
|
|
|
|
|
def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
    pipe = StableDiffusionPipeline.from_single_file(
        checkpoint_path,
        upcast_attention=True if v2 else False,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    # Re-convert the UNet from the raw state dict so the config
    # (e.g. use_linear_projection for v2) stays under our control.
    _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)
    unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
    unet_config["class_embed_type"] = None
    unet_config["addition_embed_type"] = None
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    if clip_skip is not None:
        if v2:
            text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
        else:
            text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)

    del pipe

    return tokenizer, text_encoder, unet, vae
|
|
|
|
|
def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    CLIPTokenizer,
    CLIPTextModel,
    UNet2DConditionModel,
    Optional[SchedulerMixin],
    AutoencoderKL,
]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        tokenizer, text_encoder, unet, vae = load_checkpoint_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )
    else:
        tokenizer, text_encoder, unet, vae = load_diffusers_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )

    if scheduler_name:
        scheduler = create_noise_scheduler(
            scheduler_name,
            prediction_type="v_prediction" if v_pred else "epsilon",
        )
    else:
        scheduler = None

    return tokenizer, text_encoder, unet, scheduler, vae
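# Illustrative usage (not part of the original module); the model path and
# scheduler name below are placeholders:
#
#   tokenizer, text_encoder, unet, scheduler, vae = load_models(
#       "runwayml/stable-diffusion-v1-5",  # or a local .ckpt/.safetensors path
#       scheduler_name="ddim",
#       v2=False,
#       v_pred=False,
#       weight_dtype=torch.float16,
#   )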
|
|
|
|
|
def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    AutoencoderKL,
]:
    # SDXL uses two tokenizers and two text encoders (CLIP ViT-L and OpenCLIP ViT-bigG).
    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    return tokenizers, text_encoders, unet, vae
|
|
|
|
|
def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    AutoencoderKL,
]:
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    vae = pipe.vae
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    # the second (OpenCLIP) text encoder expects 0 as its pad token id
    text_encoders[1].pad_token_id = 0

    del pipe

    return tokenizers, text_encoders, unet, vae
|
|
|
|
|
def load_models_xl(
    pretrained_model_name_or_path: str,
    scheduler_name: str,
    weight_dtype: torch.dtype = torch.float32,
    noise_scheduler_kwargs=None,
) -> Tuple[
    List[CLIPTokenizer],
    List[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    Optional[SchedulerMixin],
    AutoencoderKL,
]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        tokenizers, text_encoders, unet, vae = load_checkpoint_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )
    else:
        tokenizers, text_encoders, unet, vae = load_diffusers_model_xl(
            pretrained_model_name_or_path, weight_dtype
        )

    if scheduler_name:
        scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)
    else:
        scheduler = None

    return tokenizers, text_encoders, unet, scheduler, vae
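# Illustrative usage (not part of the original module); the repository id is a
# placeholder and noise_scheduler_kwargs is optional:
#
#   tokenizers, text_encoders, unet, scheduler, vae = load_models_xl(
#       "stabilityai/stable-diffusion-xl-base-1.0",
#       scheduler_name="euler_a",
#       weight_dtype=torch.float16,
#   )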
|
|
|
def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    noise_scheduler_kwargs=None,
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    # noise_scheduler_kwargs is an optional OmegaConf node with scheduler
    # constructor arguments (e.g. beta_start, beta_end, num_train_timesteps);
    # fall back to an empty dict so callers may omit it.
    kwargs = (
        OmegaConf.to_container(noise_scheduler_kwargs)
        if noise_scheduler_kwargs is not None
        else {}
    )
    # honor the requested prediction type unless the kwargs already set one
    kwargs.setdefault("prediction_type", prediction_type)

    name = scheduler_name.lower().replace(" ", "_")
    if name == "ddim":
        scheduler = DDIMScheduler(**kwargs)
    elif name == "ddpm":
        scheduler = DDPMScheduler(**kwargs)
    elif name == "lms":
        scheduler = LMSDiscreteScheduler(**kwargs)
    elif name == "euler_a":
        scheduler = EulerAncestralDiscreteScheduler(**kwargs)
    elif name == "euler":
        scheduler = EulerDiscreteScheduler(**kwargs)
    elif name == "unipc":
        scheduler = UniPCMultistepScheduler(**kwargs)
    else:
        raise ValueError(f"Unknown scheduler name: {name}")

    return scheduler
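# Illustrative sketch (not part of the original module): scheduler kwargs can be
# supplied as an OmegaConf node; the values below mirror the module constants and
# are only an example.
#
#   noise_scheduler_kwargs = OmegaConf.create(
#       {
#           "num_train_timesteps": NUM_TRAIN_TIMESTEPS,
#           "beta_start": BETA_START,
#           "beta_end": BETA_END,
#           "beta_schedule": "scaled_linear",
#       }
#   )
#   scheduler = create_noise_scheduler("ddim", noise_scheduler_kwargs, "epsilon")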
|
|
|
|
|
def torch_gc():
    import gc

    gc.collect()
    if torch.cuda.is_available():
        with torch.cuda.device("cuda"):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
|
|
|
|
|
class CPUState(Enum):
    GPU = 0
    CPU = 1
    MPS = 2


cpu_state = CPUState.GPU
xpu_available = False
directml_enabled = False
|
|
|
|
|
def is_intel_xpu():
    global cpu_state
    global xpu_available
    if cpu_state == CPUState.GPU:
        if xpu_available:
            return True
    return False


try:
    # importing IPEX registers the torch.xpu backend
    import intel_extension_for_pytorch as ipex  # noqa: F401

    if torch.xpu.is_available():
        xpu_available = True
except Exception:
    pass

try:
    if torch.backends.mps.is_available():
        cpu_state = CPUState.MPS
        import torch.mps  # noqa: F401
except Exception:
    pass
|
|
|
|
|
def get_torch_device():
    global directml_enabled
    global cpu_state
    if directml_enabled:
        global directml_device
        # NOTE: directml_device is expected to be assigned by whatever code
        # enables DirectML; it is never set in this module.
        return directml_device
    if cpu_state == CPUState.MPS:
        return torch.device("mps")
    if cpu_state == CPUState.CPU:
        return torch.device("cpu")
    if is_intel_xpu():
        return torch.device("xpu")
    return torch.device(torch.cuda.current_device())
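# Illustrative usage (not part of the original module): pick the device once and
# move the modules returned by the loaders above onto it.
#
#   device = get_torch_device()
#   unet.to(device)          # e.g. the UNet from load_models / load_models_xl
#   text_encoder.to(device)
#   vae.to(device)
#   torch_gc()               # release cached GPU memory when done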