import os, math, random, argparse, logging from pathlib import Path from typing import Optional, Union, List, Callable from collections import OrderedDict from packaging import version from tqdm.auto import tqdm from omegaconf import OmegaConf import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import torchvision class PFODESolver(): def __init__(self, scheduler, t_initial=1, t_terminal=0,) -> None: self.t_initial = t_initial self.t_terminal = t_terminal self.scheduler = scheduler train_step_terminal = 0 train_step_initial = train_step_terminal + self.scheduler.config.num_train_timesteps # 0+1000 self.stepsize = (t_terminal-t_initial) / (train_step_terminal - train_step_initial) #1/1000 def get_timesteps(self, t_start, t_end, num_steps): # (b,) -> (b,1) t_start = t_start[:, None] t_end = t_end[:, None] assert t_start.dim() == 2 timepoints = torch.arange(0, num_steps, 1).expand(t_start.shape[0], num_steps).to(device=t_start.device) interval = (t_end - t_start) / (torch.ones([1], device=t_start.device) * num_steps) timepoints = t_start + interval * timepoints timesteps = (self.scheduler.num_train_timesteps - 1) + (timepoints - self.t_initial) / self.stepsize # correspondint to StableDiffusion indexing system, from 999 (t_init) -> 0 (dt) return timesteps.round().long() def solve(self, latents, unet, t_start, t_end, prompt_embeds, negative_prompt_embeds, guidance_scale=1.0, num_steps = 2, num_windows = 1, ): assert t_start.dim() == 1 assert guidance_scale >= 1 and torch.all(torch.gt(t_start, t_end)) do_classifier_free_guidance = True if guidance_scale > 1 else False bsz = latents.shape[0] if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) timestep_cond = None if unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(bsz) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=unet.config.time_cond_proj_dim ).to(device=latents.device, dtype=latents.dtype) timesteps = self.get_timesteps(t_start, t_end, num_steps).to(device=latents.device) timestep_interval = self.scheduler.config.num_train_timesteps // (num_windows * num_steps) # Denoising loop with torch.no_grad(): for i in range(num_steps): t = torch.cat([timesteps[:, i]]*2) if do_classifier_free_guidance else timesteps[:, i] latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) noise_pred = unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, return_dict=False, )[0] if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) ##### STEP: compute the previous noisy sample x_t -> x_t-1 batch_timesteps = timesteps[:, i].cpu() prev_timestep = batch_timesteps - timestep_interval alpha_prod_t = self.scheduler.alphas_cumprod[batch_timesteps] alpha_prod_t_prev = torch.zeros_like(alpha_prod_t) for ib in range(prev_timestep.shape[0]): alpha_prod_t_prev[ib] = self.scheduler.alphas_cumprod[prev_timestep[ib]] if prev_timestep[ib] >= 0 else self.scheduler.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t alpha_prod_t = alpha_prod_t.to(device=latents.device, dtype=latents.dtype) alpha_prod_t_prev = alpha_prod_t_prev.to(device=latents.device, dtype=latents.dtype) beta_prod_t = beta_prod_t.to(device=latents.device, dtype=latents.dtype) if self.scheduler.config.prediction_type == "epsilon": pred_original_sample = (latents - beta_prod_t[:,None,None,None] ** (0.5) * noise_pred) / alpha_prod_t[:, None,None,None] ** (0.5) pred_epsilon = noise_pred elif self.scheduler.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t[:,None,None,None]**0.5) * latents - (beta_prod_t[:,None,None,None]**0.5) * noise_pred pred_epsilon = (alpha_prod_t[:,None,None,None]**0.5) * noise_pred + (beta_prod_t[:,None,None,None]**0.5) * latents else: raise NotImplementedError pred_sample_direction = (1 - alpha_prod_t_prev[:,None,None,None]) ** (0.5) * pred_epsilon latents = alpha_prod_t_prev[:,None,None,None] ** (0.5) * pred_original_sample + pred_sample_direction return latents