Spaces:
Running
on
Zero
Running
on
Zero
import ast | |
import gc | |
import torch | |
from collections import OrderedDict | |
from diffusers.models.attention_processor import AttnProcessor2_0 | |
from diffusers.models.attention import BasicTransformerBlock | |
import wandb | |
def extract_into_tensor(a, t, x_shape):
    """Gather ``a`` along its last axis at indices ``t`` and reshape for broadcasting.

    The result has shape ``(t.shape[0], 1, ..., 1)`` with ``len(x_shape)``
    dimensions in total, so it broadcasts against a tensor of shape ``x_shape``.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    singleton_axes = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_axes)
def is_attn(name):
    """Return True if the last dotted component of ``name`` is ``attn1`` or ``attn2``.

    The previous expression ``"attn1" or "attn2" == ...`` always evaluated to
    the truthy string ``"attn1"``, so every name matched.
    """
    return name.split(".")[-1] in ("attn1", "attn2")


def set_processors(attentions):
    """Install a fresh ``AttnProcessor2_0`` on each attention module in ``attentions``.

    ``None`` entries are skipped: a ``BasicTransformerBlock`` may be built
    without a second attention layer (``attn2 is None``) — TODO confirm for the
    diffusers version in use.
    """
    for attn in attentions:
        if attn is not None:
            attn.set_processor(AttnProcessor2_0())


def set_torch_2_attn(unet):
    """Switch transformer blocks in ``unet`` to scaled-dot-product attention processors.

    Every ``torch.nn.ModuleList`` reachable through ``unet.modules()`` is
    scanned; each ``BasicTransformerBlock`` inside it gets ``AttnProcessor2_0``
    on its ``attn1`` and ``attn2``. Prints how many blocks were updated.
    """
    optim_count = 0
    # No name filter here: the modules named ``...attn1``/``...attn2`` are the
    # attention layers themselves (they receive ``set_processor`` below), never
    # ModuleLists, so filtering by ``is_attn(name)`` would match nothing. The
    # old code only worked because its ``is_attn`` was always truthy.
    for module in unet.modules():
        if not isinstance(module, torch.nn.ModuleList):
            continue
        for block in module:
            if isinstance(block, BasicTransformerBlock):
                set_processors([block.attn1, block.attn2])
                optim_count += 1
    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
# From LatentConsistencyModel.get_guidance_scale_embedding | |
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """Build sinusoidal embeddings of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    NOTE(review): this function is redefined later in this module; the later
    definition shadows this one.

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales. Each value is multiplied by 1000
            before being embedded.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.Tensor`: shape `(len(w), embedding_dim)`, laid out as
        `[sin(...), cos(...)]` and zero-padded by one column when
        `embedding_dim` is odd. Allocated on the same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    # Clamp the denominator: with half_dim == 1 (embedding_dim 2 or 3) the old
    # division by zero gave exp(0 * -inf) = NaN instead of a frequency of 1.
    emb = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
    # Build the frequency table on w's device so CUDA inputs do not fail with a
    # device mismatch in the outer product below.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
def append_dims(x, target_dims):
    """Return ``x`` with trailing size-1 axes appended until it has ``target_dims`` dims.

    NOTE(review): this function is redefined later in this module; the later
    definition shadows this one.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        return x[(...,) + (None,) * missing]
    raise ValueError(
        f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
    )
# From LCMScheduler.get_scalings_for_boundary_condition_discrete | |
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Return the boundary-condition scalings ``(c_skip, c_out)`` for ``timestep``.

    With ``t = timestep_scaling * timestep`` and ``s = sigma_data``:
    ``c_skip = s**2 / (t**2 + s**2)`` and ``c_out = t / sqrt(t**2 + s**2)``.

    NOTE(review): this function is redefined later in this module; the later
    definition shadows this one.
    """
    t = timestep_scaling * timestep
    sigma_sq = sigma_data**2
    denom = t**2 + sigma_sq
    return sigma_sq / denom, t / denom**0.5
# Compare LCMScheduler.step, Step 4 | |
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Recover the predicted clean sample ``x_0`` from a model output.

    The ``alphas``/``sigmas`` entries at ``timesteps`` are gathered via
    ``extract_into_tensor`` and reshaped to broadcast against ``sample``.

    Supported ``prediction_type`` values:
        * ``"epsilon"``: ``x_0 = (sample - sigma * output) / alpha``
        * ``"sample"``: ``x_0 = output``
        * ``"v_prediction"``: ``x_0 = alpha * sample - sigma * output``

    NOTE(review): this function is redefined later in this module; the later
    definition shadows this one.

    Raises:
        ValueError: for any other ``prediction_type``.
    """
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        return (sample - sigma_t * model_output) / alpha_t
    if prediction_type == "sample":
        return model_output
    if prediction_type == "v_prediction":
        return alpha_t * sample - sigma_t * model_output
    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
# Based on step 4 in DDIMScheduler.step | |
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Recover the predicted noise ``epsilon`` from a model output.

    The ``alphas``/``sigmas`` entries at ``timesteps`` are gathered via
    ``extract_into_tensor`` and reshaped to broadcast against ``sample``.

    Supported ``prediction_type`` values:
        * ``"epsilon"``: ``eps = output``
        * ``"sample"``: ``eps = (sample - alpha * output) / sigma``
        * ``"v_prediction"``: ``eps = alpha * output + sigma * sample``

    NOTE(review): this function is redefined later in this module; the later
    definition shadows this one.

    Raises:
        ValueError: for any other ``prediction_type``.
    """
    alpha_t = extract_into_tensor(alphas, timesteps, sample.shape)
    sigma_t = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        return model_output
    if prediction_type == "sample":
        return (sample - alpha_t * model_output) / sigma_t
    if prediction_type == "v_prediction":
        return alpha_t * model_output + sigma_t * sample
    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )
# From LatentConsistencyModel.get_guidance_scale_embedding | |
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """Build sinusoidal embeddings of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales. Each value is multiplied by 1000
            before being embedded.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.Tensor`: shape `(len(w), embedding_dim)`, laid out as
        `[sin(...), cos(...)]` and zero-padded by one column when
        `embedding_dim` is odd. Allocated on the same device as `w`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    # Clamp the denominator: with half_dim == 1 (embedding_dim 2 or 3) the old
    # division by zero gave exp(0 * -inf) = NaN instead of a frequency of 1.
    emb = torch.log(torch.tensor(10000.0)) / max(half_dim - 1, 1)
    # Build the frequency table on w's device so CUDA inputs do not fail with a
    # device mismatch in the outer product below.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
def append_dims(x, target_dims):
    """Pad the shape of ``x`` with trailing size-1 axes up to ``target_dims`` dimensions.

    Raises:
        ValueError: when ``x.ndim`` already exceeds ``target_dims``.
    """
    current = x.ndim
    if current > target_dims:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    index = (Ellipsis,) + tuple(None for _ in range(target_dims - current))
    return x[index]
# From LCMScheduler.get_scalings_for_boundary_condition_discrete | |
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Compute ``(c_skip, c_out)`` for the consistency-model boundary condition.

    ``timestep`` is first multiplied by ``timestep_scaling``; the two factors
    then share the normaliser ``scaled**2 + sigma_data**2``.
    """
    scaled_t = timestep_scaling * timestep
    total = scaled_t**2 + sigma_data**2
    c_skip = sigma_data**2 / total
    c_out = scaled_t / total**0.5
    return c_skip, c_out
# Compare LCMScheduler.step, Step 4 | |
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert a model output into a prediction of the clean sample ``x_0``.

    ``alphas`` and ``sigmas`` are indexed at ``timesteps`` (through
    ``extract_into_tensor``) so they broadcast against ``sample``. The output is
    interpreted according to ``prediction_type``: ``"epsilon"`` (noise),
    ``"sample"`` (already ``x_0``) or ``"v_prediction"`` (velocity).

    Raises:
        ValueError: for an unrecognised ``prediction_type``.
    """
    a = extract_into_tensor(alphas, timesteps, sample.shape)
    s = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        result = (sample - s * model_output) / a
    elif prediction_type == "sample":
        result = model_output
    elif prediction_type == "v_prediction":
        result = a * sample - s * model_output
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
            f" are supported."
        )
    return result
# Based on step 4 in DDIMScheduler.step | |
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Convert a model output into a prediction of the added noise ``epsilon``.

    ``alphas`` and ``sigmas`` are indexed at ``timesteps`` (through
    ``extract_into_tensor``) so they broadcast against ``sample``. The output is
    interpreted according to ``prediction_type``: ``"epsilon"`` (already the
    noise), ``"sample"`` (``x_0``) or ``"v_prediction"`` (velocity).

    Raises:
        ValueError: for an unrecognised ``prediction_type``.
    """
    a = extract_into_tensor(alphas, timesteps, sample.shape)
    s = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        result = model_output
    elif prediction_type == "sample":
        result = (sample - a * model_output) / s
    elif prediction_type == "v_prediction":
        result = a * model_output + s * sample
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
            f" are supported."
        )
    return result
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    """Bundle a model and its training flags into one optimizer-spec dict.

    Args:
        model: a module (its ``named_parameters`` are used), or a list of
            parameter iterables when ``is_lora`` is set.
        condition: whether this entry should be trained at all.
        extra_params: extra keys merged into each optimizer param-group dict.
            ``None`` and an empty dict both mean "no extras".
        is_lora: whether ``model`` holds LoRA parameters.
        negation: stored as-is; not interpreted here.

    Returns:
        dict with keys ``model``, ``condition``, ``extra_params``, ``is_lora``,
        ``negation``, in that order. Keep that order: ``create_optimizer_params``
        unpacks ``dict.values()`` positionally.
    """
    # The old ``len(extra_params.keys())`` raised AttributeError for the
    # default ``extra_params=None``; truthiness covers both None and {}.
    extra_params = extra_params if extra_params else None
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params,
        "is_lora": is_lora,
        "negation": negation,
    }
def create_optim_params(name="param", params=None, lr=5e-6, extra_params=None):
    """Build one optimizer param-group dict.

    Returns ``{"name": name, "params": params, "lr": lr}`` with every entry of
    ``extra_params`` merged on top, so keys in ``extra_params`` (including
    ``"lr"``) override the defaults.
    """
    group = {"name": name, "params": params, "lr": lr}
    if extra_params is not None:
        group.update(extra_params)
    return group


def create_optimizer_params(model_list, lr):
    """Turn ``param_optim``-style entries into a list of optimizer param groups.

    Each entry's values are unpacked, in order, as ``model``, ``condition``,
    ``extra_params``, ``is_lora``, ``negation``. Entries with a falsy
    ``condition`` contribute nothing. Otherwise:

    * ``is_lora`` with a list ``model``: one group chaining every parameter
      iterable in the list.
    * ``is_lora`` with a module: one group per parameter whose name contains
      ``"lora"``.
    * not ``is_lora``: one group per parameter whose name does NOT contain
      ``"lora"``.

    Every group gets ``lr`` unless its ``extra_params`` supply their own ``"lr"``.
    """
    import itertools

    optimizer_params = []
    for optim in model_list:
        # Positional unpacking relies on the key order produced by param_optim.
        model, condition, extra_params, is_lora, _negation = optim.values()
        if not condition:
            continue
        if is_lora and isinstance(model, list):
            # Pass ``lr`` explicitly: this group previously fell back to
            # create_optim_params' 5e-6 default unless extra_params had "lr".
            optimizer_params.append(
                create_optim_params(
                    params=itertools.chain(*model), lr=lr, extra_params=extra_params
                )
            )
        elif is_lora:
            optimizer_params.extend(
                create_optim_params(n, p, lr, extra_params)
                for n, p in model.named_parameters()
                if "lora" in n
            )
        else:
            # Full training: keep any LoRA weights out of the optimizer.
            optimizer_params.extend(
                create_optim_params(n, p, lr, extra_params)
                for n, p in model.named_parameters()
                if "lora" not in n
            )
    return optimizer_params
def handle_trainable_modules(
    model, trainable_modules=None, is_enabled=True, negation=None
):
    """Set ``requires_grad`` on ``model``'s parameters according to ``trainable_modules``.

    Args:
        model: module exposing ``requires_grad_``, ``parameters`` and
            ``named_parameters`` (e.g. a ``torch.nn.Module``).
        trainable_modules: iterable of name substrings, or ``None`` to leave
            ``model`` untouched. If any entry equals ``"all"``, every parameter
            is unfrozen. Otherwise all parameters are frozen first, then each
            parameter whose name contains at least one entry and does not
            contain ``"lora"`` gets ``requires_grad = is_enabled``.
        is_enabled: value applied to matching parameters in selective mode.
            NOTE(review): ignored in ``"all"`` mode, which always unfreezes.
        negation: accepted for interface compatibility; not used.

    Returns:
        int: number of parameter tensors affected. That is every parameter in
        ``"all"`` mode, the matching ones in selective mode, and 0 when
        ``trainable_modules`` is ``None``. Previously nothing was returned.
    """
    if trainable_modules is None:
        return 0
    # Materialize once; a generator would otherwise be exhausted by the
    # "all" check before the per-parameter scan.
    trainable_modules = list(trainable_modules)
    if "all" in trainable_modules:
        model.requires_grad_(True)
        return len(list(model.parameters()))

    model.requires_grad_(False)
    unfrozen_params = 0
    for name, param in model.named_parameters():
        # LoRA weights are never unfrozen by this helper.
        if "lora" in name:
            continue
        # any() counts each parameter once even if several entries match,
        # replacing the former O(n) list-membership check per match.
        if any(tm in name for tm in trainable_modules):
            param.requires_grad_(is_enabled)
            unfrozen_params += 1
    return unfrozen_params
def huber_loss(pred, target, huber_c=0.001):
    """Pseudo-Huber loss averaged over all elements.

    Both inputs are cast to float32 first; the per-element term is
    ``sqrt((pred - target)**2 + huber_c**2) - huber_c``.
    """
    diff_sq = (pred.float() - target.float()) ** 2
    per_element = torch.sqrt(diff_sq + huber_c**2) - huber_c
    return per_element.mean()
def update_ema(target_params, source_params, rate=0.99):
    """Move each target tensor toward its source with an exponential moving average.

    For every pair produced by ``zip(target_params, source_params)`` the target
    is updated in place to ``rate * target + (1 - rate) * source``. A ``rate``
    closer to 1 means the target changes more slowly. Operates on a detached
    view, so no autograd history is recorded.
    """
    keep = rate
    blend = 1 - rate
    for ema_param, live_param in zip(target_params, source_params):
        buf = ema_param.detach()
        buf.mul_(keep)
        buf.add_(live_param, alpha=blend)
def log_validation_video(pipeline, args, accelerator, save_fps):
    """Generate videos for a fixed list of prompts and log them to any wandb tracker.

    Args:
        pipeline: callable taking ``prompt``, ``frames``, ``num_inference_steps``,
            ``num_videos_per_prompt`` and ``generator`` keyword arguments and
            returning a tensor whose values are clamped to ``[-1, 1]`` here.
            NOTE(review): the ``permute(0, 2, 1, 3, 4)`` below suggests the
            output is ``(batch, channels, frames, height, width)`` — confirm.
        args: namespace providing ``seed`` (``None`` for an unseeded run) and
            ``n_frames``.
        accelerator: object exposing ``device`` and ``trackers``.
        save_fps: frame rate passed to ``wandb.Video``.
    """
    if args.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
    validation_prompts = [
        "An astronaut riding a horse.",
        "Darth vader surfing in waves.",
        "Robot dancing in times square.",
        "Clown fish swimming through the coral reef.",
        "A child excitedly swings on a rusty swing set, laughter filling the air.",
        "With the style of van gogh, A young couple dances under the moonlight by the lake.",
        "A young woman with glasses is jogging in the park wearing a pink headband.",
        "Impressionist style, a yellow rubber duck floating on the wave on the sunset",
    ]
    video_logs = []
    for prompt in validation_prompts:
        with torch.autocast("cuda"):
            videos = pipeline(
                prompt=prompt,
                frames=args.n_frames,
                num_inference_steps=4,
                num_videos_per_prompt=2,
                generator=generator,
            )
        # Map [-1, 1] -> [0, 255] uint8, swap axes 1 and 2, and move to numpy.
        videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0
        videos = (videos * 255).to(torch.uint8).permute(0, 2, 1, 3, 4).cpu().numpy()
        video_logs.append({"validation_prompt": prompt, "videos": videos})

    for tracker in accelerator.trackers:
        if tracker.name != "wandb":
            continue
        formatted_videos = [
            wandb.Video(video, caption=log["validation_prompt"], fps=save_fps)
            for log in video_logs
            for video in log["videos"]
        ]
        tracker.log({"validation": formatted_videos})

    # Drops only this function's reference; the caller still holds the pipeline.
    del pipeline
    gc.collect()
def tuple_type(s):
    """Parse ``s`` as a Python tuple literal such as ``"(256, 256)"`` or ``"1, 2"``.

    Tuples are returned unchanged. Presumably used as an argparse ``type=``
    converter; argparse reports ``TypeError``/``ValueError`` from such
    converters as a usage error instead of a traceback.

    Raises:
        ValueError: if ``s`` is not a valid Python literal.
        TypeError: if ``s`` is a valid literal but not a tuple.
    """
    if isinstance(s, tuple):
        return s
    try:
        value = ast.literal_eval(s)
    except (ValueError, SyntaxError) as err:
        # literal_eval raises SyntaxError for unparsable text, which argparse
        # does not catch; normalise it to ValueError.
        raise ValueError(f"Argument must be a tuple, got {s!r}") from err
    if isinstance(value, tuple):
        return value
    raise TypeError("Argument must be a tuple")
def load_model_checkpoint(model, ckpt, strict=True):
    """Load weights from the file ``ckpt`` into ``model`` and return ``model``.

    The file is loaded onto CPU. If the loaded object has a top-level
    ``"state_dict"`` key, that entry is used as the state dict; otherwise the
    loaded object itself is. Prints a confirmation message when done.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        ckpt: path (or file-like object) accepted by ``torch.load``.
        strict: forwarded to ``load_state_dict``; defaults to ``True`` as before.

    NOTE(review): ``torch.load`` unpickles the file; only load checkpoints from
    trusted sources. ``weights_only`` is left at torch's default because wrapped
    checkpoints may carry non-tensor objects — confirm before tightening.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict, strict=strict)
    # Release the loaded CPU copy of the weights right away.
    del state_dict
    gc.collect()
    print(">>> model checkpoint loaded.")
    return model