import ast
import gc
import itertools

import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.attention_processor import AttnProcessor2_0

import wandb


def extract_into_tensor(a, t, x_shape):
    """Gathers values from `a` at indices `t` and reshapes them to broadcast over `x_shape`."""
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def is_attn(name):
    # True when the final path component names an attention block (attn1 / attn2).
    return name.split(".")[-1] in ("attn1", "attn2")


def set_processors(attentions):
    for attn in attentions:
        attn.set_processor(AttnProcessor2_0())


def set_torch_2_attn(unet):
    optim_count = 0
    for name, module in unet.named_modules():
        if is_attn(name):
            if isinstance(module, torch.nn.ModuleList):
                for m in module:
                    if isinstance(m, BasicTransformerBlock):
                        set_processors([m.attn1, m.attn2])
                        optim_count += 1
    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")


# From LatentConsistencyModel.get_guidance_scale_embedding
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`): guidance scale values to generate embedding vectors for
        embedding_dim (`int`, *optional*, defaults to 512): dimension of the embeddings to generate
        dtype: data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    scaled_timestep = timestep_scaling * timestep
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out


# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas
    elif prediction_type == "sample":
        pred_x_0 = model_output
    elif prediction_type == "v_prediction":
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; currently, `epsilon`,"
            f" `sample`, and `v_prediction` are supported."
        )
    return pred_x_0
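
# Illustrative sanity check, not part of the training utilities above: with
# `prediction_type="epsilon"`, feeding the true noise back into
# `get_predicted_original_sample` should recover the clean sample from a
# synthetically noised one. `alphas` / `sigmas` are assumed to be the
# per-timestep sqrt(alpha_cumprod) / sqrt(1 - alpha_cumprod) lookup tensors
# that the helpers above index with `extract_into_tensor`.
def _check_epsilon_inversion(x_0, noise, timesteps, alphas, sigmas):
    a = extract_into_tensor(alphas, timesteps, x_0.shape)
    s = extract_into_tensor(sigmas, timesteps, x_0.shape)
    noisy_sample = a * x_0 + s * noise
    pred_x_0 = get_predicted_original_sample(
        noise, timesteps, noisy_sample, "epsilon", alphas, sigmas
    )
    return torch.allclose(pred_x_0, x_0, atol=1e-4)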
# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        pred_epsilon = model_output
    elif prediction_type == "sample":
        pred_epsilon = (sample - alphas * model_output) / sigmas
    elif prediction_type == "v_prediction":
        pred_epsilon = alphas * model_output + sigmas * sample
    else:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; currently, `epsilon`,"
            f" `sample`, and `v_prediction` are supported."
        )
    return pred_epsilon
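
# Minimal sketch of how the helpers above are typically combined in LCM-style
# distillation (an assumption for illustration, not code used by this module):
# the consistency-model output blends the noisy sample and the predicted clean
# sample with the boundary-condition scalings, so it reduces to the identity
# at timestep 0.
def _consistency_model_output(model_output, timesteps, sample, alphas, sigmas):
    c_skip, c_out = scalings_for_boundary_conditions(timesteps)
    c_skip = append_dims(c_skip, sample.ndim)
    c_out = append_dims(c_out, sample.ndim)
    pred_x_0 = get_predicted_original_sample(
        model_output, timesteps, sample, "epsilon", alphas, sigmas
    )
    return c_skip * sample + c_out * pred_x_0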
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    extra_params = extra_params if extra_params and len(extra_params.keys()) > 0 else None
    return {
        "model": model,
        "condition": condition,
        "extra_params": extra_params,
        "is_lora": is_lora,
        "negation": negation,
    }


def create_optim_params(name="param", params=None, lr=5e-6, extra_params=None):
    params = {"name": name, "params": params, "lr": lr}
    if extra_params is not None:
        for k, v in extra_params.items():
            params[k] = v
    return params


def create_optimizer_params(model_list, lr):
    optimizer_params = []

    for optim in model_list:
        model, condition, extra_params, is_lora, negation = optim.values()

        # Check if we are doing LoRA training.
        if is_lora and condition and isinstance(model, list):
            params = create_optim_params(
                params=itertools.chain(*model), extra_params=extra_params
            )
            optimizer_params.append(params)
            continue

        if is_lora and condition and not isinstance(model, list):
            for n, p in model.named_parameters():
                if "lora" in n:
                    params = create_optim_params(n, p, lr, extra_params)
                    optimizer_params.append(params)
            continue

        # If this is true, we can train it.
        if condition:
            for n, p in model.named_parameters():
                should_negate = "lora" in n and not is_lora
                if should_negate:
                    continue

                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)

    return optimizer_params


def handle_trainable_modules(
    model, trainable_modules=None, is_enabled=True, negation=None
):
    acc = []
    unfrozen_params = 0

    if trainable_modules is not None:
        unlock_all = any([name == "all" for name in trainable_modules])
        if unlock_all:
            model.requires_grad_(True)
            unfrozen_params = len(list(model.parameters()))
        else:
            model.requires_grad_(False)
            for name, param in model.named_parameters():
                for tm in trainable_modules:
                    if all([tm in name, name not in acc, "lora" not in name]):
                        param.requires_grad_(is_enabled)
                        acc.append(name)
                        unfrozen_params += 1


def huber_loss(pred, target, huber_c=0.001):
    loss = torch.sqrt((pred.float() - target.float()) ** 2 + huber_c**2) - huber_c
    return loss.mean()


@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
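
# Hedged usage example (illustrative only, not invoked in this module): keep an
# EMA "target" copy of the online UNet and nudge it toward the online weights
# once per optimizer step. `unet` and `target_unet` stand for two models with
# identical architectures; only their parameter sequences matter to `update_ema`.
#
#     target_unet.load_state_dict(unet.state_dict())
#     for step in range(max_train_steps):
#         ...  # forward pass, loss, backward, optimizer.step()
#         update_ema(target_unet.parameters(), unet.parameters(), rate=0.95)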
""" for targ, src in zip(target_params, source_params): targ.detach().mul_(rate).add_(src, alpha=1 - rate) def log_validation_video(pipeline, args, accelerator, save_fps): if args.seed is None: generator = None else: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) validation_prompts = [ "An astronaut riding a horse.", "Darth vader surfing in waves.", "Robot dancing in times square.", "Clown fish swimming through the coral reef.", "A child excitedly swings on a rusty swing set, laughter filling the air.", "With the style of van gogh, A young couple dances under the moonlight by the lake.", "A young woman with glasses is jogging in the park wearing a pink headband.", "Impressionist style, a yellow rubber duck floating on the wave on the sunset", ] video_logs = [] for _, prompt in enumerate(validation_prompts): with torch.autocast("cuda"): videos = pipeline( prompt=prompt, frames=args.n_frames, num_inference_steps=4, num_videos_per_prompt=2, generator=generator, ) videos = (videos.clamp(-1.0, 1.0) + 1.0) / 2.0 videos = (videos * 255).to(torch.uint8).permute(0, 2, 1, 3, 4).cpu().numpy() video_logs.append({"validation_prompt": prompt, "videos": videos}) for tracker in accelerator.trackers: if tracker.name == "wandb": formatted_videos = [] for log in video_logs: videos = log["videos"] validation_prompt = log["validation_prompt"] for video in videos: video = wandb.Video(video, caption=validation_prompt, fps=save_fps) formatted_videos.append(video) tracker.log({f"validation": formatted_videos}) del pipeline gc.collect() def tuple_type(s): if isinstance(s, tuple): return s value = ast.literal_eval(s) if isinstance(value, tuple): return value raise TypeError("Argument must be a tuple") def load_model_checkpoint(model, ckpt): def load_checkpoint(model, ckpt, full_strict): state_dict = torch.load(ckpt, map_location="cpu") if "state_dict" in list(state_dict.keys()): state_dict = state_dict["state_dict"] model.load_state_dict(state_dict, strict=full_strict) del state_dict gc.collect() return model load_checkpoint(model, ckpt, full_strict=True) print(">>> model checkpoint loaded.") return model