# Bootstrapped from Hugging Face diffusers code.
import fnmatch
import json
import math
import os
import shutil
from typing import List, Optional
import numpy as np
import torch
from diffusers.models.attention_processor import LoRAAttnProcessor2_0
from diffusers.optimization import get_scheduler
from safetensors.torch import save_file
from tqdm.auto import tqdm
from dataset_and_utils import (
PreprocessedDataset,
TokenEmbeddingsHandler,
load_models,
unet_attn_processors_state_dict,
)
def main(
    pretrained_model_name_or_path: Optional[str] = "./cache",  # e.g. "stabilityai/stable-diffusion-xl-base-1.0"
revision: Optional[str] = None,
instance_data_dir: Optional[str] = "./dataset/zeke/captions.csv",
output_dir: str = "ft_masked_coke",
seed: Optional[int] = 42,
resolution: int = 512,
crops_coords_top_left_h: int = 0,
crops_coords_top_left_w: int = 0,
train_batch_size: int = 1,
do_cache: bool = True,
num_train_epochs: int = 600,
max_train_steps: Optional[int] = None,
checkpointing_steps: int = 500000, # default to no checkpoints
gradient_accumulation_steps: int = 1, # todo
unet_learning_rate: float = 1e-5,
ti_lr: float = 3e-4,
lora_lr: float = 1e-4,
pivot_halfway: bool = True,
scale_lr: bool = False,
lr_scheduler: str = "constant",
lr_warmup_steps: int = 500,
lr_num_cycles: int = 1,
lr_power: float = 1.0,
dataloader_num_workers: int = 0,
max_grad_norm: float = 1.0, # todo with tests
allow_tf32: bool = True,
mixed_precision: Optional[str] = "bf16",
device: str = "cuda:0",
token_dict: dict = {"TOKEN": "<s0>"},
inserting_list_tokens: List[str] = ["<s0>"],
verbose: bool = True,
is_lora: bool = True,
lora_rank: int = 32,
) -> None:
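    """Run Pivotal Tuning Inversion (PTI) on an SDXL base model: train newly
    inserted token embeddings (textual inversion) jointly with either
    whitelisted UNet weights or per-attention LoRA adapters, optionally
    freezing the token embeddings halfway through ("pivoting")."""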
if allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
    if seed is None:
seed = np.random.randint(0, 2**32 - 1)
print("Using seed", seed)
torch.manual_seed(seed)
weight_dtype = torch.float32
if mixed_precision == "fp16":
weight_dtype = torch.float16
elif mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if scale_lr:
unet_learning_rate = (
unet_learning_rate * gradient_accumulation_steps * train_batch_size
)
(
tokenizer_one,
tokenizer_two,
noise_scheduler,
text_encoder_one,
text_encoder_two,
vae,
unet,
) = load_models(pretrained_model_name_or_path, revision, device, weight_dtype)
print("# PTI : Loaded models")
# Initialize new tokens for training.
embedding_handler = TokenEmbeddingsHandler(
[text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
)
embedding_handler.initialize_new_tokens(inserting_toks=inserting_list_tokens)
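    # The handler records the original embedding matrices so that every row
    # other than the newly inserted tokens can be restored after each step.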
text_encoders = [text_encoder_one, text_encoder_two]
unet_param_to_optimize = []
# fine tune only attn weights
text_encoder_parameters = []
for text_encoder in text_encoders:
for name, param in text_encoder.named_parameters():
if "token_embedding" in name:
param.requires_grad = True
print(name)
text_encoder_parameters.append(param)
else:
param.requires_grad = False
if not is_lora:
WHITELIST_PATTERNS = [
# "*.attn*.weight",
# "*ff*.weight",
"*"
] # TODO : make this a parameter
BLACKLIST_PATTERNS = ["*.norm*.weight", "*time*"]
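        # Train every parameter the whitelist matches except norm and
        # time-embedding weights, which the blacklist excludes.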
unet_param_to_optimize_names = []
for name, param in unet.named_parameters():
if any(
fnmatch.fnmatch(name, pattern) for pattern in WHITELIST_PATTERNS
) and not any(
fnmatch.fnmatch(name, pattern) for pattern in BLACKLIST_PATTERNS
):
                param.requires_grad_(True)
                # Bug fix: collect the parameters themselves, not just their
                # names, so the optimizer group below is actually populated.
                unet_param_to_optimize.append(param)
                unet_param_to_optimize_names.append(name)
                print(f"Training: {name}")
else:
param.requires_grad_(False)
# Optimizer creation
params_to_optimize = [
{
"params": unet_param_to_optimize,
"lr": unet_learning_rate,
},
{
"params": text_encoder_parameters,
"lr": ti_lr,
"weight_decay": 1e-3,
},
]
else:
# Do lora-training instead.
unet.requires_grad_(False)
unet_lora_attn_procs = {}
unet_lora_parameters = []
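        # UNet attention processors are keyed by module path. Self-attention
        # layers ("attn1") take no cross-attention dim; the hidden size depends
        # on whether the processor sits in a down, mid, or up block.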
for name, attn_processor in unet.attn_processors.items():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
module = LoRAAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=lora_rank,
)
unet_lora_attn_procs[name] = module
module.to(device)
unet_lora_parameters.extend(module.parameters())
unet.set_attn_processor(unet_lora_attn_procs)
params_to_optimize = [
{
"params": unet_lora_parameters,
"lr": lora_lr,
},
{
"params": text_encoder_parameters,
"lr": ti_lr,
"weight_decay": 1e-3,
},
]
optimizer = torch.optim.AdamW(
params_to_optimize,
weight_decay=1e-4,
)
print(f"# PTI : Loading dataset, do_cache {do_cache}")
    train_dataset = PreprocessedDataset(
        instance_data_dir,
        tokenizer_one,
        tokenizer_two,
        vae.float(),
        do_cache=do_cache,  # was hardcoded to True, ignoring the parameter
        substitute_caption_map=token_dict,
    )
print("# PTI : Loaded dataset")
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=True,
num_workers=dataloader_num_workers,
)
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / gradient_accumulation_steps
)
if max_train_steps is None:
max_train_steps = num_train_epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler(
lr_scheduler,
optimizer=optimizer,
num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
num_training_steps=max_train_steps * gradient_accumulation_steps,
num_cycles=lr_num_cycles,
power=lr_power,
)
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
total_batch_size = train_batch_size * gradient_accumulation_steps
if verbose:
print(f"# PTI : Running training ")
print(f"# PTI : Num examples = {len(train_dataset)}")
print(f"# PTI : Num batches each epoch = {len(train_dataloader)}")
print(f"# PTI : Num Epochs = {num_train_epochs}")
print(f"# PTI : Instantaneous batch size per device = {train_batch_size}")
print(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
print(f"# PTI : Gradient Accumulation steps = {gradient_accumulation_steps}")
print(f"# PTI : Total optimization steps = {max_train_steps}")
global_step = 0
first_epoch = 0
    # Progress bar over optimization steps (single-device script).
progress_bar = tqdm(range(global_step, max_train_steps))
checkpoint_dir = "checkpoint"
if os.path.exists(checkpoint_dir):
shutil.rmtree(checkpoint_dir)
os.makedirs(f"{checkpoint_dir}/unet", exist_ok=True)
os.makedirs(f"{checkpoint_dir}/embeddings", exist_ok=True)
for epoch in range(first_epoch, num_train_epochs):
if pivot_halfway:
if epoch == num_train_epochs // 2:
print("# PTI : Pivot halfway")
            # Pivot: drop the text-encoder (token-embedding) parameter group
            # and rebuild the optimizer over the UNet/LoRA weights only.
params_to_optimize = params_to_optimize[:1]
optimizer = torch.optim.AdamW(
params_to_optimize,
weight_decay=1e-4,
)
unet.train()
        for step, batch in enumerate(train_dataloader):
            # Respect max_train_steps even when it does not align with a full
            # epoch (num_train_epochs is rounded up above).
            if global_step >= max_train_steps:
                break
            progress_bar.update(1)
            progress_bar.set_description(f"# PTI : step: {global_step}, epoch: {epoch}")
global_step += 1
(tok1, tok2), vae_latent, mask = batch
vae_latent = vae_latent.to(weight_dtype)
# tokens to text embeds
prompt_embeds_list = []
for tok, text_encoder in zip((tok1, tok2), text_encoders):
prompt_embeds_out = text_encoder(
tok.to(text_encoder.device),
output_hidden_states=True,
)
pooled_prompt_embeds = prompt_embeds_out[0]
prompt_embeds = prompt_embeds_out.hidden_states[-2]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
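            # SDXL concatenates the two encoders' penultimate hidden states
            # along the channel dim and uses only the second encoder's pooled
            # output (the last value left by the loop above).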
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
            # SDXL micro-conditioning: original size, crop top-left, and target
            # size are packed into "time ids" alongside the pooled text embeds.
original_size = (resolution, resolution)
target_size = (resolution, resolution)
crops_coords_top_left = (crops_coords_top_left_h, crops_coords_top_left_w)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(device, dtype=prompt_embeds.dtype).repeat(
bs_embed, 1
)
added_kw = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
# Sample noise that we'll add to the latents
noise = torch.randn_like(vae_latent)
bsz = vae_latent.shape[0]
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(bsz,),
device=vae_latent.device,
)
timesteps = timesteps.long()
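            # Forward diffusion: noise the latents at the sampled timesteps to
            # form the model input x_t.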
noisy_model_input = noise_scheduler.add_noise(vae_latent, noise, timesteps)
# Predict the noise residual
model_pred = unet(
noisy_model_input,
timesteps,
prompt_embeds,
added_cond_kwargs=added_kw,
).sample
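            # Masked MSE against the sampled noise (epsilon-prediction target);
            # the per-pixel mask confines the loss to the masked region.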
loss = (model_pred - noise).pow(2) * mask
loss = loss.mean()
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
            # After each step, restore all pre-existing token embeddings to
            # their original values so only the inserted tokens are trained
            # (the handler iterates over both text encoders internally).
            embedding_handler.retract_embeddings()
if global_step % checkpointing_steps == 0:
# save the required params of unet with safetensor
if not is_lora:
tensors = {
name: param
for name, param in unet.named_parameters()
if name in unet_param_to_optimize_names
}
save_file(
tensors,
f"{checkpoint_dir}/unet/checkpoint-{global_step}.unet.safetensors",
)
else:
lora_tensors = unet_attn_processors_state_dict(unet)
save_file(
lora_tensors,
f"{checkpoint_dir}/unet/checkpoint-{global_step}.lora.safetensors",
)
embedding_handler.save_embeddings(
f"{checkpoint_dir}/embeddings/checkpoint-{global_step}.pti",
)
    # Final save
    print("Saving final model for return")
    os.makedirs(output_dir, exist_ok=True)
if not is_lora:
tensors = {
name: param
for name, param in unet.named_parameters()
if name in unet_param_to_optimize_names
}
save_file(
tensors,
f"{output_dir}/unet.safetensors",
)
else:
lora_tensors = unet_attn_processors_state_dict(unet)
save_file(
lora_tensors,
f"{output_dir}/lora.safetensors",
)
embedding_handler.save_embeddings(
f"{output_dir}/embeddings.pti",
)
    # Persist the token mapping so inference can substitute the trained tokens.
    with open(f"{output_dir}/special_params.json", "w") as f:
        json.dump(token_dict, f)
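
# Example: LoRA + textual-inversion training on a custom caption CSV. A
# minimal sketch; the model path, CSV path, and token names below are
# placeholders, not values shipped with this repo:
#
#   main(
#       pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
#       instance_data_dir="./dataset/my_subject/captions.csv",
#       output_dir="./out",
#       token_dict={"TOKEN": "<s0>"},
#       inserting_list_tokens=["<s0>"],
#       is_lora=True,
#       lora_rank=32,
#   )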
if __name__ == "__main__":
main()