alcm / ldm /models /diffusion /scheduling_lcm.py
inLine-XJY's picture
Upload 335 files
2b5b9ef verified
raw
history blame
24.9 kB
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from typing import List, Optional, Tuple, Union
from ldm.util import randn_tensor
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
extract_into_tensor
class LCMSampler(object):
    """Multistep Latent Consistency Model (LCM) sampler wrapped around a diffusion model.

    The sampler reads `num_timesteps`, `betas`, `alphas_cumprod`, `alphas_cumprod_prev`, `device`,
    `apply_model` and `unet` from the wrapped `model`. Typical use is `sample(...)`, which builds the
    noise schedule, selects the inference timesteps and runs the denoising loop.
    """

    # Stride used when slicing a custom timestep schedule for `strength`; this sampler consumes one
    # schedule entry per `step` call.
    order = 1
    # Default `sigma_data` for the boundary-condition scalings (see
    # `get_scalings_for_boundary_condition_discrete`). May be overridden per instance.
    sigma_data = 0.5

    def __init__(self, model, **kwargs):
        """
        Args:
            model:
                Diffusion model wrapper; `model.num_timesteps` gives the number of training timesteps.
            **kwargs:
                Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.original_inference_steps = 100
        # setable values
        self.num_inference_steps = None
        # Until `set_timesteps` is called, the schedule is every training timestep in descending order.
        self.timesteps = torch.from_numpy(np.arange(0, self.ddpm_num_timesteps)[::-1].copy().astype(np.int64))
        self.custom_timesteps = False
        self.timestep_scaling = 10.0
        self.prediction_type = 'epsilon'
        # Initialised here so that the `step_index` property is readable before `set_timesteps` runs.
        self._step_index = None
        self._num_timesteps = None

    def register_buffer(self, name, attr):
        """
        Stores `attr` on the sampler under `name`.

        Tensors are moved to `self.model.device` when the model exposes one; otherwise they are moved to
        CUDA only if CUDA is available. Non-tensor values are stored unchanged.
        """
        if isinstance(attr, torch.Tensor):
            target = getattr(self.model, "device", None)
            if target is None and torch.cuda.is_available():
                target = torch.device("cuda")
            if target is not None:
                attr = attr.to(target)
        setattr(self, name, attr)

    def make_schedule(self, ddim_discretize="uniform", verbose=True):
        """
        Copies the model's noise schedule (`betas`, `alphas_cumprod`, `alphas_cumprod_prev`) onto the
        sampler as detached float32 tensors on the model's device.

        Args:
            ddim_discretize, verbose:
                Unused; kept so existing callers keep working.
        """
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

        def to_torch(x):
            return x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

    def progress_bar(self, iterable=None, total=None):
        """
        Returns a `tqdm` progress bar configured with `self._progress_bar_config` (an empty dict is
        created on first use).

        Args:
            iterable: Iterable to wrap. Takes precedence over `total`.
            total: Number of expected updates for a manually driven bar.

        Raises:
            ValueError: If `_progress_bar_config` is not a dict, or neither argument is given.
        """
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )
        if iterable is not None:
            return tqdm(iterable, **self._progress_bar_config)
        elif total is not None:
            return tqdm(total=total, **self._progress_bar_config)
        else:
            raise ValueError("Either `total` or `iterable` has to be defined.")

    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        Builds sinusoidal embeddings of the guidance weights.

        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                1-D tensor of guidance weights; one embedding row is produced per entry.
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings
        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0
        half_dim = embedding_dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb

    @property
    def step_index(self):
        """Index into `self.timesteps` of the next `step` call, or `None` before the first step."""
        return self._step_index

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        original_inference_steps: Optional[int] = None,
        timesteps: Optional[List[int]] = None,
        strength: float = 1.0,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            original_inference_steps (`int`, *optional*):
                The original number of inference steps, which will be used to generate a linearly-spaced timestep
                schedule. We will then take `num_inference_steps` timesteps from this schedule, evenly spaced in
                terms of indices, and use that as our final timestep schedule. If not set, this will default to
                the `original_inference_steps` attribute.
            timesteps (`List[int]`, *optional*):
                Custom timesteps, in strictly descending order. If `None`, the default strategy of equal spacing
                on the training/distillation schedule is used. If `timesteps` is passed, `num_inference_steps`
                must be `None`.
            strength (`float`, *optional*, defaults to 1.0):
                Fraction of the schedule to keep. It shortens the training schedule and, for custom timesteps,
                drops the leading `(1 - strength)` share of the entries.

        Raises:
            ValueError: On inconsistent arguments or a schedule that cannot be built.
        """
        # Imported locally: the module has no top-level logger and the warnings below need one.
        import logging

        logger = logging.getLogger(__name__)

        # 0. Check inputs
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
        if timesteps is not None and len(timesteps) == 0:
            raise ValueError("`custom_timesteps` must contain at least one timestep.")

        # 1. Calculate the LCM original training/distillation timestep schedule.
        original_steps = (
            original_inference_steps if original_inference_steps is not None else self.original_inference_steps
        )
        if original_steps > self.ddpm_num_timesteps:
            raise ValueError(
                f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.ddpm_num_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.ddpm_num_timesteps} timesteps."
            )

        # LCM Timesteps Setting
        # The skipping step parameter k from the paper.
        k = self.ddpm_num_timesteps // original_steps
        # LCM Training/Distillation Steps Schedule
        # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
        lcm_origin_timesteps = np.arange(1, int(original_steps * strength) + 1) * k - 1

        # 2. Calculate the LCM inference timestep schedule.
        if timesteps is not None:
            # 2.1 Handle custom timestep schedules.
            train_timesteps = set(lcm_origin_timesteps)
            non_train_timesteps = []
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`custom_timesteps` must be in descending order.")
                if timesteps[i] not in train_timesteps:
                    non_train_timesteps.append(timesteps[i])

            if timesteps[0] >= self.ddpm_num_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`:"
                    f" {self.ddpm_num_timesteps}."
                )

            # Raise warning if timestep schedule does not start with self.ddpm_num_timesteps - 1
            if strength == 1.0 and timesteps[0] != self.ddpm_num_timesteps - 1:
                logger.warning(
                    f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                    f" `self.ddpm_num_timesteps - 1`: {self.ddpm_num_timesteps - 1}. You may get"
                    f" unexpected results when using this timestep schedule."
                )

            # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
            if non_train_timesteps:
                logger.warning(
                    f"The custom timestep schedule contains the following timesteps which are not on the original"
                    f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                    f" when using this timestep schedule."
                )

            # Raise warning if custom timestep schedule is longer than original_steps
            if len(timesteps) > original_steps:
                logger.warning(
                    f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
                    f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                    f" unexpected results when using this timestep schedule."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.num_inference_steps = len(timesteps)
            self.custom_timesteps = True

            # Apply strength (e.g. for img2img pipelines): keep only the trailing part of the schedule.
            # `num_inference_steps` keeps the untruncated length; `step` relies on `len(self.timesteps)`
            # to detect the final step, so the two may legitimately differ here.
            init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
            t_start = max(self.num_inference_steps - init_timestep, 0)
            timesteps = timesteps[t_start * self.order:]
        else:
            # 2.2 Create the "standard" LCM inference timestep schedule.
            if num_inference_steps > self.ddpm_num_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.ddpm_num_timesteps`:"
                    f" {self.ddpm_num_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.ddpm_num_timesteps} timesteps."
                )

            skipping_step = len(lcm_origin_timesteps) // num_inference_steps
            if skipping_step < 1:
                raise ValueError(
                    f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than"
                    f" `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to"
                    f" a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher"
                    f" than {float(num_inference_steps / original_steps)}."
                )

            if num_inference_steps > original_steps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                    f" {original_steps} because the final timestep schedule will be a subset of the"
                    f" `original_inference_steps`-sized initial timestep schedule."
                )

            # Only commit the new step count once every check has passed.
            self.num_inference_steps = num_inference_steps

            # LCM Inference Steps Schedule
            lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
            # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
            inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
            inference_indices = np.floor(inference_indices).astype(np.int64)
            timesteps = lcm_origin_timesteps[inference_indices]

        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
        # Force `step` to re-locate its position in the new schedule.
        self._step_index = None

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
    def retrieve_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        timesteps: Optional[List[int]] = None,
        **kwargs,
    ):
        """
        Calls `set_timesteps` and returns the resulting schedule. Handles custom timesteps. Any kwargs
        are forwarded to `set_timesteps`.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. When given, it takes
                precedence and `num_inference_steps` is ignored.
        Returns:
            `Tuple[torch.Tensor, int]`: The timestep schedule and the number of inference steps (the length
            of the schedule when custom timesteps are used, otherwise `num_inference_steps` as passed in).
        """
        if timesteps is not None:
            self.set_timesteps(timesteps=timesteps, device=device, **kwargs)
            timesteps = self.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.set_timesteps(num_inference_steps, device=device, **kwargs)
            timesteps = self.timesteps
        return timesteps, num_inference_steps

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               verbose=True,
               x_T=None,
               guidance_scale=5.,
               original_inference_steps=50,
               timesteps=None,
               **kwargs
               ):
        """
        Runs LCM sampling end to end.

        Args:
            S (`int`): Number of inference steps.
            batch_size (`int`): Number of samples to generate.
            shape: Per-sample shape, either `(C, H, W)` or `(C, T)`.
            conditioning: Tensor, or dict whose first value is a tensor or (nested) list of tensors. A
                warning is printed when its leading dimension differs from `batch_size`.
            x_T: Optional starting noise; drawn from a standard normal when `None`.
            guidance_scale (`float`): Guidance weight; `guidance_scale - 1` is embedded and passed to the model.
            original_inference_steps (`int`): Length of the training/distillation schedule to subsample.
            timesteps: Optional custom timestep schedule.
            callback, normals_sequence, img_callback, **kwargs: Unused; kept for call-site compatibility.

        Returns:
            `tuple`: `(samples, intermediates)` exactly as returned by `lcm_sampling`.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(verbose=verbose)
        self.num_inference_steps = S

        # sampling
        if len(shape) not in (2, 3):
            raise ValueError(f"`shape` must be (C, H, W) or (C, T), got {tuple(shape)}.")
        size = (batch_size, *shape)

        samples, intermediates = self.lcm_sampling(conditioning, size,
                                                   x_T=x_T,
                                                   guidance_scale=guidance_scale,
                                                   original_inference_steps=original_inference_steps,
                                                   timesteps=timesteps
                                                   )
        return samples, intermediates

    @torch.no_grad()
    def lcm_sampling(self, cond, shape,
                     x_T=None,
                     guidance_scale=1., original_inference_steps=100, timesteps=None):
        """
        LCM multistep sampling loop.

        Args:
            cond: Conditioning forwarded to `self.model.apply_model`.
            shape: Full sample shape including the batch dimension.
            x_T: Optional starting noise; drawn from a standard normal when `None`.
            guidance_scale (`float`): Guidance weight; `guidance_scale - 1` is embedded (dimension 256) and
                passed to the model as `w_cond`.
            original_inference_steps (`int`): Forwarded to `set_timesteps`.
            timesteps: Optional custom timestep schedule.

        Returns:
            `tuple`: `(denoised, img)` - the last denoised prediction and the last sample returned by `step`.
        """
        device = self.model.betas.device
        timesteps, num_inference_steps = self.retrieve_timesteps(
            self.num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps
        )
        b = shape[0]
        img = torch.randn(shape, device=device) if x_T is None else x_T

        w = torch.tensor(guidance_scale - 1).repeat(b)
        w_embedding = self.get_guidance_scale_embedding(w, embedding_dim=256).to(
            device=device, dtype=img.dtype
        )

        # Only tensor-like conditioning carries a dtype to align with; dict/None conditioning is
        # passed through without casting the sample.
        cond_dtype = getattr(cond, "dtype", None)

        # LCM MultiStep Sampling Loop:
        num_warmup_steps = len(timesteps) - num_inference_steps
        self._num_timesteps = len(timesteps)
        denoised = img  # keeps the return value defined even for an empty schedule
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if cond_dtype is not None:
                    img = img.to(cond_dtype)
                ts = torch.full((b,), t, device=device, dtype=torch.long)
                # model prediction (v-prediction, eps, x)
                model_pred = self.model.apply_model(img, ts, cond, self.model.unet, w_cond=w_embedding)
                # compute the previous noisy sample x_t -> x_t-1
                img, denoised = self.step(model_pred, t, img, return_dict=False)
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps):
                    progress_bar.update()
        return denoised, img

    # Adapted from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        """Sets `_step_index` to the position of `timestep` inside `self.timesteps`."""
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        index_candidates = (self.timesteps == timestep).nonzero()
        # If the timestep occurs more than once, the **very** first `step` starts from the second
        # occurrence so that no entry is skipped when starting in the middle of the schedule
        # (e.g. for image-to-image).
        if len(index_candidates) > 1:
            step_index = index_candidates[1]
        else:
            step_index = index_candidates[0]
        self._step_index = step_index.item()

    def get_scalings_for_boundary_condition_discrete(self, timestep):
        """
        Returns the boundary-condition scalings `(c_skip, c_out)` for `timestep`.

        With `s = timestep * self.timestep_scaling` and `d = self.sigma_data`:
        `c_skip = d**2 / (s**2 + d**2)` and `c_out = s / sqrt(s**2 + d**2)`.
        """
        scaled_timestep = timestep * self.timestep_scaling
        denom = scaled_timestep**2 + self.sigma_data**2
        c_skip = self.sigma_data**2 / denom
        c_out = scaled_timestep / denom ** 0.5
        return c_skip, c_out

    def _randn(self, shape, device, generator=None):
        """Draws standard-normal noise on `device`, honouring `generator` when one is supplied."""
        if generator is None:
            return torch.randn(shape, device=device)
        if generator.device.type != torch.device(device).type:
            # A generator can only produce tensors on its own device type; sample there, then move.
            return torch.randn(shape, generator=generator, device=generator.device).to(device)
        return torch.randn(shape, generator=generator, device=device)

    @torch.no_grad()
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ):
        """
        Predict the sample at the previous timestep of the schedule from the model output.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`int` or 0-dim `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                Random number generator used for the injected noise.
            return_dict (`bool`, *optional*, defaults to `True`):
                Accepted for API compatibility; a tuple is returned in either case.
        Returns:
            `tuple`: `(prev_sample, denoised)`.

        Raises:
            ValueError: If `set_timesteps` has not been run, or `prediction_type` is unsupported.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )
        if self.step_index is None:
            self._init_step_index(timestep)

        # 1. get previous step value
        prev_step_index = self.step_index + 1
        if prev_step_index < len(self.timesteps):
            prev_timestep = self.timesteps[prev_step_index]
        else:
            prev_timestep = timestep

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else torch.tensor(1.0)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 3. Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

        # 4. Compute the predicted original sample x_0 based on the model parameterization
        if self.prediction_type == "epsilon":  # noise-prediction
            predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
        elif self.prediction_type == "sample":  # x-prediction
            predicted_original_sample = model_output
        elif self.prediction_type == "v_prediction":  # v-prediction
            predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for `LCMScheduler`."
            )

        # 5. Denoise model output using boundary conditions
        denoised = c_out * predicted_original_sample + c_skip * sample

        # 6. Sample and inject noise z ~ N(0, I) for MultiStep Inference
        # Noise is not used on the final timestep of the timestep schedule.
        # This also means that noise is not used for one-step sampling.
        # The final step is detected from the actual schedule length, which can be shorter than
        # `num_inference_steps` when a custom schedule was truncated by `strength`.
        if self.step_index != len(self.timesteps) - 1:
            noise = self._randn(model_output.shape, model_output.device, generator)
            prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
        else:
            prev_sample = denoised

        # upon completion increase step index by one
        self._step_index += 1
        return prev_sample, denoised