# S2O_DPM/utils.py
import ast
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from safetensors import safe_open


def update_args_from_yaml(group, args, parser):
    """Recursively copy values from a (possibly nested) YAML mapping onto an
    argparse namespace, coercing each value to the type registered on the
    parser for that destination."""
for key, value in group.items():
if isinstance(value, dict):
update_args_from_yaml(value, args, parser)
else:
            if value in ("None", "null"):
                value = None
            else:
                # Look up the converter registered on the parser for this destination.
                arg_type = next((action.type for action in parser._actions if action.dest == key), str)
                if arg_type is ast.literal_eval:
                    # literal_eval-typed arguments keep the value YAML already parsed.
                    pass
                elif isinstance(arg_type, type) and not isinstance(value, arg_type):
                    # argparse allows any callable as `type`; only coerce when it is
                    # an actual class, so isinstance() cannot raise here.
                    try:
                        value = arg_type(value)
                    except ValueError as e:
                        raise ValueError(f"Cannot convert {key} to {arg_type}: {e}") from e
setattr(args, key, value)
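
# Usage sketch (illustrative; the config file name and the --lr flag are
# hypothetical, not part of this repo):
#
#   import argparse, yaml
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--lr", type=float, default=1e-4)
#   args = parser.parse_args([])
#   with open("config.yaml") as f:
#       update_args_from_yaml(yaml.safe_load(f), args, parser)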


def safe_load(model_path):
    """Load a .safetensors checkpoint into a plain CPU state dict."""
assert "safetensors" in model_path
state_dict = {}
with safe_open(model_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
return state_dict
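
# Usage sketch (illustrative; the checkpoint path and `model` are hypothetical):
#
#   model.load_state_dict(safe_load("checkpoints/unet.safetensors"))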


@dataclass
class DDIMSchedulerStepOutput:
prev_sample: torch.Tensor # x_{t-1}
pred_original_sample: Optional[torch.Tensor] = None # x0


@dataclass
class DDIMSchedulerConversionOutput:
pred_epsilon: torch.Tensor
pred_original_sample: torch.Tensor
pred_velocity: torch.Tensor


class DDIMScheduler:
    """Minimal DDIM scheduler supporting "epsilon", "sample", and "v_prediction"
    model outputs, with "trailing", "linspace", and "leading" inference schedules."""

    prediction_types = ["epsilon", "sample", "v_prediction"]

def __init__(
self,
num_train_timesteps: int,
num_inference_timesteps: int,
betas: torch.Tensor,
set_alpha_to_one: bool = True,
set_inference_timesteps_from_pure_noise: bool = True,
inference_timesteps: Union[str, List[int]] = "trailing",
device: Optional[Union[str, torch.device]] = None,
dtype: torch.dtype = torch.float32,
        skip_step: bool = False,
        original_inference_step: int = 20,
        steps_offset: int = 0,
):
assert num_train_timesteps > 0
assert num_train_timesteps >= num_inference_timesteps
assert num_train_timesteps == betas.size(0)
assert betas.ndim == 1
self.module_name = 'AutoAIGC'
self.config_list = {"num_train_timesteps": num_train_timesteps,
"num_inference_timesteps": num_inference_timesteps,
"betas": betas,
"set_alpha_to_one": set_alpha_to_one,
"set_inference_timesteps_from_pure_noise": set_inference_timesteps_from_pure_noise,
"inference_timesteps": inference_timesteps}
self.module_info = str(self.config_list)
device = device or betas.device
self.num_train_timesteps = num_train_timesteps
self.num_inference_steps = num_inference_timesteps
self.steps_offset = steps_offset
        self.betas = betas  # moved to `device` at the end of __init__ via self.to(); dtype left unchanged
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.final_alpha_cumprod = torch.tensor(1.0, device=device, dtype=dtype) if set_alpha_to_one else self.alphas_cumprod[0]
        if isinstance(inference_timesteps, torch.Tensor):
            assert len(inference_timesteps) == num_inference_timesteps
            self.timesteps = inference_timesteps.cpu().numpy().tolist()
        elif set_inference_timesteps_from_pure_noise:
            if inference_timesteps == "trailing":
                # Example 20 steps: [999, 949, 899, ..., 149, 99, 49]
                if skip_step:
                    # Build a longer trailing schedule with original_inference_step
                    # entries, then subsample it down to num_inference_timesteps.
                    original_timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / original_inference_step, device=device).round().int().tolist()
                    skipping_step = len(original_timesteps) // num_inference_timesteps
                    self.timesteps = original_timesteps[::skipping_step][:num_inference_timesteps]
                else:
                    # Example 10 steps: [999, 899, 799, 699, 599, 499, 399, 299, 199, 99]
                    self.timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / num_inference_timesteps, device=device).round().int().tolist()
            elif inference_timesteps == "linspace":
                # Fixed DDIM timesteps; the schedule starts from 999.
                # Example 20 steps: [999, 946, 894, ..., 105, 53, 0]
                # Example 10 steps: [999, 888, 777, 666, 555, 444, 333, 222, 111, 0]
                self.timesteps = torch.linspace(0, num_train_timesteps - 1, num_inference_timesteps, device=device).round().int().flip(0).tolist()
            elif inference_timesteps == "leading":
                # Schedule from the original SD / DDIM code, which may be buggy
                # because it does not start from 999:
                # <https://github.com/huggingface/diffusers/issues/2585>
                # Example 20 steps: [950, 900, 850, ..., 100, 50, 0]
                # Example 10 steps: [900, 800, 700, 600, 500, 400, 300, 200, 100, 0]
                step_ratio = num_train_timesteps // num_inference_timesteps
                # Integer timesteps by multiplying by the ratio; .tolist() keeps
                # self.timesteps a plain list, which step() requires for .index().
                self.timesteps = torch.arange(0, num_inference_timesteps).mul(step_ratio).flip(dims=[0]).tolist()
            else:
                raise NotImplementedError
        elif inference_timesteps == "leading":
            # Same "leading" schedule as above (does not start from 999).
            self.timesteps = list(reversed(range(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps)))
        else:
            self.timesteps = list(reversed(range(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps)))
self.to(device=device)
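
    # Construction sketch (illustrative; the scaled-linear beta schedule below is
    # an assumption borrowed from Stable Diffusion, not defined in this file):
    #
    #   betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
    #   scheduler = DDIMScheduler(
    #       num_train_timesteps=1000,
    #       num_inference_timesteps=20,
    #       betas=betas,
    #       inference_timesteps="trailing",
    #   )
    #   print(scheduler.timesteps)  # [999, 949, ..., 49]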

    def to(self, device):
self.betas = self.betas.to(device)
self.alphas_cumprod = self.alphas_cumprod.to(device)
self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
        # self.timesteps is a plain Python list, so there is nothing to move.
return self

    def step(
self,
model_output: torch.Tensor,
model_output_type: str,
timestep: Union[torch.Tensor, int],
sample: torch.Tensor,
eta: float = 0.0,
clip_sample: bool = False,
dynamic_threshold: Optional[float] = None,
variance_noise: Optional[torch.Tensor] = None,
) -> DDIMSchedulerStepOutput:
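        """Take one reverse DDIM step from `timestep` toward the previous entry
        of self.timesteps; see Eqs. (12) and (16) of
        https://arxiv.org/abs/2010.02502."""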
        # 1. get previous step value (t-1)
        if isinstance(timestep, int):
idx = self.timesteps.index(timestep)
prev_timestep = self.timesteps[idx + 1] if idx < self.num_inference_steps - 1 else None
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep is not None else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
        else:
            timesteps = torch.tensor(self.timesteps).to(timestep.device)
            # Find the index of each entry of `timestep` within the schedule.
            idx = timestep.reshape(-1, 1).eq(timesteps.reshape(1, -1)).nonzero()[:, 1]
            # The previous timestep lives at idx + 1; clamp so the final step
            # maps to the last entry of the schedule instead of running off the end.
            prev_timestep = timesteps[idx.add(1).clamp_max(self.num_inference_steps - 1)]
            assert prev_timestep is not None
            # 2. compute alphas, betas
            alpha_prod_t = self.alphas_cumprod[timestep]
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
            alpha_prod_t_prev = torch.where(prev_timestep < 0, self.final_alpha_cumprod, alpha_prod_t_prev)
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev
            bs = timestep.size(0)
            alpha_prod_t = alpha_prod_t.view(bs, 1, 1, 1)
            alpha_prod_t_prev = alpha_prod_t_prev.view(bs, 1, 1, 1)
            beta_prod_t = beta_prod_t.view(bs, 1, 1, 1)
            beta_prod_t_prev = beta_prod_t_prev.view(bs, 1, 1, 1)
        # rcfg: stash the previous-step coefficients for external use.
        self.stock_alpha_prod_t_prev = alpha_prod_t_prev
        self.stock_beta_prod_t_prev = beta_prod_t_prev
        # 3. compute predicted original sample from predicted noise, also called "predicted x_0"
model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
pred_original_sample = model_output_conversion.pred_original_sample
pred_epsilon = model_output_conversion.pred_epsilon
# 4. Clip or threshold "predicted x_0"
if clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
if dynamic_threshold is not None:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = pred_original_sample \
.flatten(1) \
.abs() \
.float() \
.quantile(dynamic_threshold, dim=1) \
.type_as(pred_original_sample) \
.clamp_min(1) \
.view(-1, *([1] * (pred_original_sample.ndim - 1)))
pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon
        # 5. compute variance σ_t(η); see Eq. (16) of https://arxiv.org/abs/2010.02502
        # σ_t = sqrt((1 − ᾱ_{t−1}) / (1 − ᾱ_t)) · sqrt(1 − ᾱ_t / ᾱ_{t−1})
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
std_dev_t = eta * variance ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
# 8. add "random noise" if needed.
if eta > 0:
if variance_noise is None:
variance_noise = torch.randn_like(model_output)
prev_sample = prev_sample + std_dev_t * variance_noise
return DDIMSchedulerStepOutput(
prev_sample=prev_sample, # x_{t-1}
pred_original_sample=pred_original_sample # x0
)
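
    # Sampling-loop sketch (illustrative; `unet`, `latents`, and `cond` are
    # hypothetical stand-ins for a real denoiser and its inputs):
    #
    #   for t in scheduler.timesteps:
    #       eps = unet(latents, t, cond)
    #       latents = scheduler.step(eps, "epsilon", t, latents).prev_sample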

    def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: Union[torch.Tensor, int],
        replace_noise: bool = True,
    ) -> torch.Tensor:
        # Forward diffusion: x_t = sqrt(ᾱ_t) · x_0 + sqrt(1 − ᾱ_t) · ε
        # Normalize `timesteps` to a 1-D tensor so int inputs also work below.
        timesteps = torch.as_tensor(timesteps, device=self.alphas_cumprod.device).reshape(-1)
        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (original_samples.ndim - 1)))
        if replace_noise:
            # At t = 999, force ᾱ_t = 0 so the noised sample is pure noise.
            indices = (timesteps == 999).nonzero()
            if indices.numel() > 0:
                alpha_prod_t[indices] = 0
return alpha_prod_t ** (0.5) * original_samples + (1 - alpha_prod_t) ** (0.5) * noise
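
    # Training usage sketch (illustrative; `x0` is a clean latent batch and
    # `scheduler` an instance of this class):
    #
    #   t = torch.randint(0, scheduler.num_train_timesteps, (x0.size(0),))
    #   noise = torch.randn_like(x0)
    #   x_t = scheduler.add_noise(x0, noise, t)
    #   v_target = scheduler.get_velocity(x0, noise, t)  # v-prediction target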

    def add_noise_lcm(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timestep: Union[torch.Tensor, int],
) -> torch.Tensor:
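        """Noise `original_samples` to the timestep *preceding* `timestep` in the
        inference schedule, i.e. a skipping-step construction in the style of
        latent consistency models (this reading is an interpretation; the file
        itself does not document the method)."""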
if isinstance(timestep, int):
# 1. get previous step value (t-1)
idx = self.timesteps.index(timestep)
prev_timestep = self.timesteps[idx + 1] if idx < self.num_inference_steps - 1 else None
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep is not None else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
else:
timesteps = torch.tensor(self.timesteps).to(timestep.device)
            # Find the index of each entry of `timestep` within the schedule.
            idx = timestep.reshape(-1, 1).eq(timesteps.reshape(1, -1)).nonzero()[:, 1]
            # The previous timestep lives at idx + 1, clamped to the schedule end.
            prev_timestep = timesteps[idx.add(1).clamp_max(self.num_inference_steps - 1)]
assert (prev_timestep is not None)
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
alpha_prod_t_prev = torch.where(prev_timestep < 0, self.final_alpha_cumprod, alpha_prod_t_prev)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
bs = timestep.size(0)
alpha_prod_t = alpha_prod_t.view(bs, 1, 1, 1)
alpha_prod_t_prev = alpha_prod_t_prev.view(bs, 1, 1, 1)
beta_prod_t = beta_prod_t.view(bs, 1, 1, 1)
beta_prod_t_prev = beta_prod_t_prev.view(bs, 1, 1, 1)
alpha_prod_t_prev = alpha_prod_t_prev.reshape(-1, *([1] * (original_samples.ndim - 1)))
return alpha_prod_t_prev ** (0.5) * original_samples + (1 - alpha_prod_t_prev) ** (0.5) * noise

    def convert_output(
self,
model_output: torch.Tensor,
model_output_type: str,
sample: torch.Tensor,
timesteps: Union[torch.Tensor, int]
) -> DDIMSchedulerConversionOutput:
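        """Convert `model_output` between the three DDIM parameterizations.

        With ᾱ_t = alphas_cumprod[t]:
            x_0 = (x_t − sqrt(1 − ᾱ_t) · ε) / sqrt(ᾱ_t)
            ε   = (x_t − sqrt(ᾱ_t) · x_0) / sqrt(1 − ᾱ_t)
            v   = sqrt(ᾱ_t) · ε − sqrt(1 − ᾱ_t) · x_0
        """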
assert model_output_type in self.prediction_types
alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
beta_prod_t = 1 - alpha_prod_t
if model_output_type == "epsilon":
pred_epsilon = model_output
pred_original_sample = (sample - beta_prod_t ** (0.5) * pred_epsilon) / alpha_prod_t ** (0.5)
pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
elif model_output_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
elif model_output_type == "v_prediction":
pred_velocity = model_output
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError("Unknown prediction type")
return DDIMSchedulerConversionOutput(
pred_epsilon=pred_epsilon,
pred_original_sample=pred_original_sample,
pred_velocity=pred_velocity)

    def get_velocity(
self,
sample: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor
    ) -> torch.Tensor:
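        """v-prediction target: v = sqrt(ᾱ_t) · ε − sqrt(1 − ᾱ_t) · x_0."""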
alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
return alpha_prod_t ** (0.5) * noise - (1 - alpha_prod_t) ** (0.5) * sample
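

if __name__ == "__main__":
    # Smoke-test sketch (illustrative): drive the reverse process with random
    # tensors standing in for a denoiser, just to check shapes and the schedule.
    betas = torch.linspace(1e-4, 0.02, 1000)  # assumed linear beta schedule
    scheduler = DDIMScheduler(num_train_timesteps=1000, num_inference_timesteps=10, betas=betas)
    x = torch.randn(1, 4, 8, 8)
    for t in scheduler.timesteps:
        fake_eps = torch.randn_like(x)  # stand-in for a UNet's epsilon prediction
        x = scheduler.step(fake_eps, "epsilon", t, x).prev_sample
    print(x.shape)  # torch.Size([1, 4, 8, 8])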