# Source: T2V-Turbo / utils/common_utils.py
# Uploaded by Ji4chenLi, commit 9f200a2 ("initial test"), 13.5 kB.
import ast
import gc
import torch
from collections import OrderedDict
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.models.attention import BasicTransformerBlock
import wandb
def extract_into_tensor(a, t, x_shape):
    """Gather entries of ``a`` at indices ``t`` and reshape them for broadcasting.

    Args:
        a: Tensor that is indexed along its last dimension.
        t: Index tensor; its leading dimension is taken as the batch size.
        x_shape: Shape whose length fixes the number of dimensions of the result.

    Returns:
        The gathered values reshaped to ``(batch, 1, ..., 1)`` with
        ``len(x_shape)`` dimensions in total.
    """
    batch_size, *_ = t.shape
    gathered = a.gather(-1, t)
    # One singleton axis for every non-batch dimension of ``x_shape``.
    singleton_axes = [1] * (len(x_shape) - 1)
    return gathered.reshape(batch_size, *singleton_axes)
def is_attn(name):
    """Return True when the last dotted component of ``name`` is ``attn1`` or ``attn2``.

    The former expression ``"attn1" or "attn2" == name.split(".")[-1]`` always
    evaluated to the non-empty string ``"attn1"``, i.e. it was truthy for every
    input. The comparison is now applied to both candidate names.

    Args:
        name: Dotted module path, e.g. as yielded by ``named_modules()``.

    Returns:
        bool: whether the final path component names one of the two attention layers.
    """
    return name.split(".")[-1] in ("attn1", "attn2")


def set_processors(attentions):
    """Install a fresh ``AttnProcessor2_0`` on every module in ``attentions``."""
    for attn in attentions:
        attn.set_processor(AttnProcessor2_0())


def set_torch_2_attn(unet):
    """Put ``AttnProcessor2_0`` on ``attn1``/``attn2`` of every ``BasicTransformerBlock``.

    Every ``torch.nn.ModuleList`` found in ``unet`` is inspected; for each
    ``BasicTransformerBlock`` inside one, both attention layers get the new
    processor. A summary line is printed when at least one block was updated.

    Args:
        unet: Module whose sub-modules are traversed with ``named_modules()``.
    """
    optim_count = 0

    # Deliberately no ``is_attn(name)`` filter: that gate used to be always
    # truthy, so every ModuleList was examined. A correct name filter would
    # never match here (the containers are ModuleLists, not attn1/attn2), so
    # the effective traversal is kept as it was.
    for _name, module in unet.named_modules():
        if not isinstance(module, torch.nn.ModuleList):
            continue
        for m in module:
            if isinstance(m, BasicTransformerBlock):
                set_processors([m.attn1, m.attn2])
                optim_count += 1

    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
# From LatentConsistencyModel.get_guidance_scale_embedding
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """Compute sinusoidal embeddings for a 1-D tensor ``w``.

    NOTE: a second definition of this function appears later in this file and
    replaces this one at import time.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of values to embed (guidance scales, going by the
            function name); it is multiplied by 1000 before embedding.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings
    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    scaled_w = w * 1000.0

    num_freqs = embedding_dim // 2
    log_step = torch.log(torch.tensor(10000.0)) / (num_freqs - 1)
    freqs = torch.exp(torch.arange(num_freqs, dtype=dtype) * -log_step)
    angles = scaled_w.to(dtype)[:, None] * freqs[None, :]
    # Sine half first, cosine half second.
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        embedding = torch.nn.functional.pad(embedding, (0, 1))
    assert embedding.shape == (scaled_w.shape[0], embedding_dim)
    return embedding
def append_dims(x, target_dims):
    """Add trailing singleton axes to ``x`` until it has ``target_dims`` dimensions.

    NOTE: this function is defined a second time later in this file; the later
    definition is the one that stays bound after import.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Keep all existing axes, then index with ``None`` once per axis to add.
    index = (Ellipsis, *([None] * missing))
    return x[index]
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Return the pair ``(c_skip, c_out)`` for a (scaled) timestep.

    NOTE: this function is defined a second time later in this file; the later
    definition is the one that stays bound after import.

    With ``s = timestep_scaling * timestep``:
    ``c_skip = sigma_data**2 / (s**2 + sigma_data**2)`` and
    ``c_out = s / sqrt(s**2 + sigma_data**2)``.
    """
    s = timestep_scaling * timestep
    denominator = s**2 + sigma_data**2
    skip_scale = sigma_data**2 / denominator
    out_scale = s / denominator**0.5
    return skip_scale, out_scale
# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert ``model_output`` into a prediction of the original sample.

    NOTE: this function is defined a second time later in this file; the later
    definition is the one that stays bound after import.

    ``alphas`` and ``sigmas`` are looked up at ``timesteps`` via
    ``extract_into_tensor`` and shaped to broadcast against ``sample``.

    Args:
        model_output: Raw network output.
        timesteps: Index tensor used to select entries of ``alphas``/``sigmas``.
        sample: The sample the output was computed for.
        prediction_type: One of ``"epsilon"``, ``"sample"`` or ``"v_prediction"``.
        alphas: Schedule tensor indexed by ``timesteps``.
        sigmas: Schedule tensor indexed by ``timesteps``.

    Raises:
        ValueError: for any other ``prediction_type``.
    """
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        return (sample - sigma_t * model_output) / alpha_t
    if prediction_type == "sample":
        return model_output
    if prediction_type == "v_prediction":
        return alpha_t * sample - sigma_t * model_output

    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert ``model_output`` into a prediction of the noise.

    NOTE: this function is defined a second time later in this file; the later
    definition is the one that stays bound after import.

    ``alphas`` and ``sigmas`` are looked up at ``timesteps`` via
    ``extract_into_tensor`` and shaped to broadcast against ``sample``.

    Args:
        model_output: Raw network output.
        timesteps: Index tensor used to select entries of ``alphas``/``sigmas``.
        sample: The sample the output was computed for.
        prediction_type: One of ``"epsilon"``, ``"sample"`` or ``"v_prediction"``.
        alphas: Schedule tensor indexed by ``timesteps``.
        sigmas: Schedule tensor indexed by ``timesteps``.

    Raises:
        ValueError: for any other ``prediction_type``.
    """
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        return model_output
    if prediction_type == "sample":
        return (sample - alpha_t * model_output) / sigma_t
    if prediction_type == "v_prediction":
        return alpha_t * model_output + sigma_t * sample

    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
# From LatentConsistencyModel.get_guidance_scale_embedding
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """Compute sinusoidal embeddings for a 1-D tensor ``w``.

    NOTE: an earlier definition with the same name exists in this file; being
    later, this one is the binding that remains after import.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of values to embed (guidance scales, going by the
            function name); it is multiplied by 1000 before embedding.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings
    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`,
        located on the same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequency table on w's device; a CPU-only table cannot be
    # multiplied with a `w` that lives on an accelerator.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    # Sine half first, cosine half second.
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
def append_dims(x, target_dims):
    """Add trailing singleton axes to ``x`` until it has ``target_dims`` dimensions.

    NOTE: an earlier definition with the same name exists in this file; being
    later, this one is the binding that remains after import.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    extra_axes = target_dims - x.ndim
    if extra_axes < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Ellipsis keeps every existing axis; each ``None`` adds one size-1 axis.
    new_axes = (None,) * extra_axes
    return x[(Ellipsis, *new_axes)]
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Return the pair ``(c_skip, c_out)`` for a (scaled) timestep.

    NOTE: an earlier definition with the same name exists in this file; being
    later, this one is the binding that remains after import.

    With ``s = timestep_scaling * timestep``:
    ``c_skip = sigma_data**2 / (s**2 + sigma_data**2)`` and
    ``c_out = s / sqrt(s**2 + sigma_data**2)``.
    """
    s = timestep_scaling * timestep
    sum_of_squares = s**2 + sigma_data**2
    return sigma_data**2 / sum_of_squares, s / sum_of_squares**0.5
# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert ``model_output`` into a prediction of the original sample.

    NOTE: an earlier definition with the same name exists in this file; being
    later, this one is the binding that remains after import.

    ``alphas`` and ``sigmas`` are looked up at ``timesteps`` via
    ``extract_into_tensor`` and shaped to broadcast against ``sample``.

    Args:
        model_output: Raw network output.
        timesteps: Index tensor used to select entries of ``alphas``/``sigmas``.
        sample: The sample the output was computed for.
        prediction_type: One of ``"epsilon"``, ``"sample"`` or ``"v_prediction"``.
        alphas: Schedule tensor indexed by ``timesteps``.
        sigmas: Schedule tensor indexed by ``timesteps``.

    Raises:
        ValueError: for any other ``prediction_type``.
    """
    a_t = extract_into_tensor(alphas, timesteps, sample.shape)
    s_t = extract_into_tensor(sigmas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        # Remove the predicted noise, then undo the signal scaling.
        return (sample - s_t * model_output) / a_t
    if prediction_type == "sample":
        # The network already outputs the original sample.
        return model_output
    if prediction_type == "v_prediction":
        return a_t * sample - s_t * model_output

    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert ``model_output`` into a prediction of the noise.

    NOTE: an earlier definition with the same name exists in this file; being
    later, this one is the binding that remains after import.

    ``alphas`` and ``sigmas`` are looked up at ``timesteps`` via
    ``extract_into_tensor`` and shaped to broadcast against ``sample``.

    Args:
        model_output: Raw network output.
        timesteps: Index tensor used to select entries of ``alphas``/``sigmas``.
        sample: The sample the output was computed for.
        prediction_type: One of ``"epsilon"``, ``"sample"`` or ``"v_prediction"``.
        alphas: Schedule tensor indexed by ``timesteps``.
        sigmas: Schedule tensor indexed by ``timesteps``.

    Raises:
        ValueError: for any other ``prediction_type``.
    """
    a_t = extract_into_tensor(alphas, timesteps, sample.shape)
    s_t = extract_into_tensor(sigmas, timesteps, sample.shape)

    if prediction_type == "epsilon":
        # The network already outputs the noise.
        return model_output
    if prediction_type == "sample":
        # Remove the predicted signal, then undo the noise scaling.
        return (sample - a_t * model_output) / s_t
    if prediction_type == "v_prediction":
        return a_t * model_output + s_t * sample

    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    """Bundle a model and its training options into a dict.

    Args:
        model: The model (or list of parameter iterables) to describe.
        condition: Flag stored under ``"condition"``.
        extra_params: Optional mapping of extra optimizer options. ``None`` and
            an empty mapping are both stored as ``None``. Previously the
            default ``None`` raised ``AttributeError`` because ``.keys()`` was
            called on it unconditionally.
        is_lora: Flag stored under ``"is_lora"``.
        negation: Value stored under ``"negation"``.

    Returns:
        dict with the keys ``model``, ``condition``, ``extra_params``,
        ``is_lora`` and ``negation``, in that order. The order matters:
        ``create_optimizer_params`` unpacks ``.values()`` positionally.
    """
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params if extra_params else None,
        "is_lora": is_lora,
        "negation": negation,
    }
def create_optim_params(name="param", params=None, lr=5e-6, extra_params=None):
    """Build one optimizer parameter-group dict.

    The group starts with the keys ``name``, ``params`` and ``lr``. Entries of
    ``extra_params`` (when it is not ``None``) are then written on top, so they
    can add keys or override the three base ones.
    """
    group = {"name": name, "params": params, "lr": lr}
    if extra_params is not None:
        group.update(extra_params.items())
    return group
def create_optimizer_params(model_list, lr):
    """Turn ``param_optim``-style entries into a list of optimizer parameter groups.

    Each entry of ``model_list`` must be a dict whose values are, in order:
    model, condition, extra_params, is_lora, negation. Entries with a falsy
    ``condition`` contribute nothing.

    * LoRA entry whose model is a ``list``: one group holding the chained
      contents of that list.
    * LoRA entry with a single model: one group per parameter whose name
      contains ``"lora"``.
    * Non-LoRA entry: one group per parameter whose name does not contain
      ``"lora"``.

    Returns:
        list of dicts produced by ``create_optim_params``.
    """
    import itertools

    groups = []

    for entry in model_list:
        # Positional unpacking: relies on the key order of the entry dict.
        model, condition, extra_params, is_lora, _negation = entry.values()

        if not condition:
            continue

        if is_lora:
            if isinstance(model, list):
                # NOTE(review): ``lr`` is not forwarded on this path, so the
                # default of ``create_optim_params`` applies unless
                # ``extra_params`` supplies one - confirm this is intended.
                groups.append(
                    create_optim_params(
                        params=itertools.chain(*model), extra_params=extra_params
                    )
                )
            else:
                groups.extend(
                    create_optim_params(n, p, lr, extra_params)
                    for n, p in model.named_parameters()
                    if "lora" in n
                )
            continue

        # Regular training: every parameter except LoRA ones.
        groups.extend(
            create_optim_params(n, p, lr, extra_params)
            for n, p in model.named_parameters()
            if "lora" not in n
        )

    return groups
def handle_trainable_modules(
    model, trainable_modules=None, is_enabled=True, negation=None
):
    """Select which parameters of ``model`` require gradients.

    * ``trainable_modules is None``: nothing is changed.
    * Any entry equal to ``"all"``: ``model.requires_grad_(True)`` is applied
      (``is_enabled`` is not consulted on this path).
    * Otherwise: the whole model is frozen first; then every parameter whose
      name contains at least one entry of ``trainable_modules`` and does not
      contain ``"lora"`` gets ``requires_grad_(is_enabled)``.

    Args:
        model: ``torch.nn.Module`` to modify in place.
        trainable_modules: Iterable of name substrings, or ``None``.
        is_enabled: Value passed to ``requires_grad_`` for matching parameters.
        negation: Unused; kept for interface compatibility.

    Returns:
        None.
    """
    if trainable_modules is None:
        return

    if any(name == "all" for name in trainable_modules):
        model.requires_grad_(True)
        return

    model.requires_grad_(False)
    for name, param in model.named_parameters():
        if "lora" in name:
            continue
        # One substring scan per parameter; no list of already-seen names is
        # needed because each parameter is visited exactly once here.
        if any(tm in name for tm in trainable_modules):
            param.requires_grad_(is_enabled)
def huber_loss(pred, target, huber_c=0.001):
    """Mean of ``sqrt((pred - target)**2 + huber_c**2) - huber_c``.

    Both inputs are cast with ``.float()`` before the difference is taken.
    """
    residual = pred.float() - target.float()
    per_element = torch.sqrt(residual**2 + huber_c**2) - huber_c
    return per_element.mean()
@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
    """
    Move target parameters towards source parameters with an exponential
    moving average, in place: ``target = rate * target + (1 - rate) * source``.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    source_weight = 1 - rate
    for ema_param, live_param in zip(target_params, source_params):
        # ``detach()`` shares storage, so the in-place ops update ``ema_param``.
        ema_param.detach().mul_(rate).add_(live_param, alpha=source_weight)
def log_validation_video(pipeline, args, accelerator, save_fps):
    """Generate videos for a fixed prompt list and log them to the wandb tracker.

    Args:
        pipeline: Callable invoked with ``prompt``, ``frames``,
            ``num_inference_steps``, ``num_videos_per_prompt`` and ``generator``.
        args: Object providing ``seed`` and ``n_frames``.
        accelerator: Object providing ``device`` and ``trackers``.
        save_fps: Frame rate passed to ``wandb.Video``.

    Only trackers whose ``name`` is ``"wandb"`` receive the videos, under the
    key ``"validation"``.
    """
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    validation_prompts = [
        "An astronaut riding a horse.",
        "Darth vader surfing in waves.",
        "Robot dancing in times square.",
        "Clown fish swimming through the coral reef.",
        "A child excitedly swings on a rusty swing set, laughter filling the air.",
        "With the style of van gogh, A young couple dances under the moonlight by the lake.",
        "A young woman with glasses is jogging in the park wearing a pink headband.",
        "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
    ]

    video_logs = []
    for prompt in validation_prompts:
        with torch.autocast("cuda"):
            videos = pipeline(
                prompt=prompt,
                frames=args.n_frames,
                num_inference_steps=4,
                num_videos_per_prompt=2,
                generator=generator,
            )
            # Clamp to [-1, 1] and rescale to [0, 1].
            # NOTE(review): assumes the pipeline returns a 5-D tensor - confirm.
            videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0
            # Convert to uint8 and swap dimensions 1 and 2 before handing the
            # array to wandb.
            videos = (videos * 255).to(torch.uint8).permute(0, 2, 1, 3, 4).cpu().numpy()
        video_logs.append({"validation_prompt": prompt, "videos": videos})

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            formatted_videos = [
                wandb.Video(video, caption=log["validation_prompt"], fps=save_fps)
                for log in video_logs
                for video in log["videos"]
            ]
            tracker.log({"validation": formatted_videos})

    # Runs once, after all trackers were served. ``del`` only drops this
    # function's reference; the caller's pipeline object is unaffected.
    del pipeline
    gc.collect()
def tuple_type(s):
    """Return ``s`` as a tuple, parsing it with ``ast.literal_eval`` if needed.

    A tuple is returned unchanged. Anything else is evaluated as a Python
    literal; the result must itself be a tuple.

    Raises:
        TypeError: if the evaluated literal is not a tuple.
    """
    if isinstance(s, tuple):
        return s

    parsed = ast.literal_eval(s)
    if not isinstance(parsed, tuple):
        raise TypeError("Argument must be a tuple")
    return parsed
def load_model_checkpoint(model, ckpt):
    """Load weights from ``ckpt`` into ``model`` (strictly) and return ``model``.

    The checkpoint is read onto the CPU with ``torch.load``. If the loaded
    object has a ``"state_dict"`` key, the value stored under it is used as the
    state dict; otherwise the loaded object itself is used. Loading uses
    ``strict=True``.
    """
    # NOTE(review): ``torch.load`` unpickles the file; only use it on
    # checkpoints from a trusted source.
    weights = torch.load(ckpt, map_location="cpu")
    if "state_dict" in weights.keys():
        weights = weights["state_dict"]

    model.load_state_dict(weights, strict=True)

    # Drop the CPU copy of the weights before returning.
    del weights
    gc.collect()

    print(">>> model checkpoint loaded.")
    return model