import ast
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from safetensors import safe_open


def update_args_from_yaml(group, args, parser):
    """Recursively copy (possibly nested) YAML values onto an argparse namespace,
    casting each value to the type registered on the parser for that destination."""
    for key, value in group.items():
        if isinstance(value, dict):
            update_args_from_yaml(value, args, parser)
        else:
            if value in ('None', 'null'):
                value = None
            else:
                # Find the type registered for this argument; fall back to str.
                arg_type = next((action.type for action in parser._actions if action.dest == key), str)

                if arg_type is ast.literal_eval:
                    # literal_eval-typed arguments take the YAML value as-is.
                    pass
                elif arg_type is not None and not isinstance(value, arg_type):
                    try:
                        value = arg_type(value)
                    except ValueError as e:
                        raise ValueError(f"Cannot convert {key} to {arg_type}: {e}")

            setattr(args, key, value)

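# Illustrative usage sketch (not part of this module): the flag names, the YAML file
# name, and the `argparse`/`yaml` setup below are assumptions for the example only.
#
#   import argparse, yaml
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--lr", type=float, default=1e-4)
#   parser.add_argument("--betas", type=ast.literal_eval, default="(0.9, 0.999)")
#   args = parser.parse_args([])
#   with open("config.yaml") as f:
#       update_args_from_yaml(yaml.safe_load(f), args, parser)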


def safe_load(model_path):
    """Load a .safetensors checkpoint into a plain CPU state dict."""
    assert "safetensors" in model_path
    state_dict = {}
    with safe_open(model_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
    return state_dict


@dataclass
class DDIMSchedulerStepOutput:
    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


@dataclass
class DDIMSchedulerConversionOutput:
    pred_epsilon: torch.Tensor
    pred_original_sample: torch.Tensor
    pred_velocity: torch.Tensor

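# The three model-output parameterizations handled below are tied together by the
# forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps and
# v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0, so given x_t and
# alpha_bar_t, any one of (eps, x_0, v) determines the other two (see convert_output).
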
class DDIMScheduler:
    prediction_types = ["epsilon", "sample", "v_prediction"]

    def __init__(
        self,
        num_train_timesteps: int,
        num_inference_timesteps: int,
        betas: torch.Tensor,
        set_alpha_to_one: bool = True,
        set_inference_timesteps_from_pure_noise: bool = True,
        inference_timesteps: Union[str, List[int]] = "trailing",
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        skip_step: bool = False,
        original_inference_step: int = 20,
        steps_offset: int = 0,
    ):
        assert num_train_timesteps > 0
        assert num_train_timesteps >= num_inference_timesteps
        assert num_train_timesteps == betas.size(0)
        assert betas.ndim == 1

        self.module_name = 'AutoAIGC'
        self.config_list = {
            "num_train_timesteps": num_train_timesteps,
            "num_inference_timesteps": num_inference_timesteps,
            "betas": betas,
            "set_alpha_to_one": set_alpha_to_one,
            "set_inference_timesteps_from_pure_noise": set_inference_timesteps_from_pure_noise,
            "inference_timesteps": inference_timesteps,
        }
        self.module_info = str(self.config_list)

        device = device or betas.device

        self.num_train_timesteps = num_train_timesteps
        self.num_inference_steps = num_inference_timesteps
        self.steps_offset = steps_offset

        self.betas = betas
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar used when stepping past the end of the schedule; 1.0 corresponds to a fully denoised sample.
        self.final_alpha_cumprod = torch.tensor(1.0, device=device, dtype=dtype) if set_alpha_to_one else self.alphas_cumprod[0]

        if isinstance(inference_timesteps, torch.Tensor):
            assert len(inference_timesteps) == num_inference_timesteps
            self.timesteps = inference_timesteps.cpu().numpy().tolist()
        elif set_inference_timesteps_from_pure_noise:
            if inference_timesteps == "trailing":
                if skip_step:
                    # Build the trailing grid for the original step count, then subsample it.
                    original_timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / original_inference_step, device=device).round().int().tolist()
                    skipping_step = len(original_timesteps) // num_inference_timesteps
                    self.timesteps = original_timesteps[::skipping_step][:num_inference_timesteps]
                else:
                    self.timesteps = torch.arange(num_train_timesteps - 1, -1, -num_train_timesteps / num_inference_timesteps, device=device).round().int().tolist()
            elif inference_timesteps == "linspace":
                self.timesteps = torch.linspace(0, num_train_timesteps - 1, num_inference_timesteps, device=device).round().int().flip(0).tolist()
            elif inference_timesteps == "leading":
                step_ratio = num_train_timesteps // num_inference_timesteps
                # step_ratio and the arange are integral, so no rounding is needed; convert to a
                # Python list like the other branches so that self.timesteps.index() works in step().
                self.timesteps = torch.arange(0, num_inference_timesteps, device=device).mul(step_ratio).flip(dims=[0]).int().tolist()
            else:
                raise NotImplementedError

        else:
            # When not starting from pure noise, both the "leading" spacing and the
            # fallback use the same stride-based grid.
            self.timesteps = list(reversed(range(0, num_train_timesteps, num_train_timesteps // num_inference_timesteps)))

        self.to(device=device)

    def to(self, device):
        self.betas = self.betas.to(device)
        self.alphas_cumprod = self.alphas_cumprod.to(device)
        self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
        return self

    def step(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        timestep: Union[torch.Tensor, int],
        sample: torch.Tensor,
        eta: float = 0.0,
        clip_sample: bool = False,
        dynamic_threshold: Optional[float] = None,
        variance_noise: Optional[torch.Tensor] = None,
    ) -> DDIMSchedulerStepOutput:
        if isinstance(timestep, int):
            # Scalar timestep: look up its position in the inference schedule and the following
            # (less noisy) timestep; None means we step past the end of the schedule and use
            # final_alpha_cumprod.
            idx = self.timesteps.index(timestep)
            prev_timestep = self.timesteps[idx + 1] if idx < self.num_inference_steps - 1 else None

            alpha_prod_t = self.alphas_cumprod[timestep]
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep is not None else self.final_alpha_cumprod
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev
        else:
            # Batched timesteps: locate each one in the schedule and gather the next timestep,
            # clamping at the end of the schedule.
            timesteps = torch.tensor(self.timesteps).to(timestep.device)
            idx = timestep.reshape(-1, 1).eq(timesteps.reshape(1, -1)).nonzero()[:, 1]
            prev_timestep = timesteps[idx.add(1).clamp_max(self.num_inference_steps - 1)]

            assert prev_timestep is not None

            alpha_prod_t = self.alphas_cumprod[timestep]
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
            alpha_prod_t_prev = torch.where(prev_timestep < 0, self.final_alpha_cumprod, alpha_prod_t_prev)
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev

            bs = timestep.size(0)
            alpha_prod_t = alpha_prod_t.view(bs, 1, 1, 1)
            alpha_prod_t_prev = alpha_prod_t_prev.view(bs, 1, 1, 1)
            beta_prod_t = beta_prod_t.view(bs, 1, 1, 1)
            beta_prod_t_prev = beta_prod_t_prev.view(bs, 1, 1, 1)

        self.stock_alpha_prod_t_prev = alpha_prod_t_prev
        self.stock_beta_prod_t_prev = beta_prod_t_prev

        # Convert whatever the model predicts (epsilon / sample / v) into x0 and epsilon.
        model_output_conversion = self.convert_output(model_output, model_output_type, sample, timestep)
        pred_original_sample = model_output_conversion.pred_original_sample
        pred_epsilon = model_output_conversion.pred_epsilon

        if clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon

        if dynamic_threshold is not None:
            # Dynamic thresholding: clamp each sample to its own per-example quantile and rescale.
            dynamic_max_val = pred_original_sample \
                .flatten(1) \
                .abs() \
                .float() \
                .quantile(dynamic_threshold, dim=1) \
                .type_as(pred_original_sample) \
                .clamp_min(1) \
                .view(-1, *([1] * (pred_original_sample.ndim - 1)))
            pred_original_sample = pred_original_sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
            pred_epsilon = self.convert_output(pred_original_sample, "sample", sample, timestep).pred_epsilon

        # DDIM variance; eta = 0 gives the deterministic sampler.
        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        std_dev_t = eta * variance ** (0.5)

        # "Direction pointing to x_t" term of the DDIM update.
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

        # x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_hat + direction (+ sigma_t * noise when eta > 0).
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            if variance_noise is None:
                variance_noise = torch.randn_like(model_output)
            prev_sample = prev_sample + std_dev_t * variance_noise

        return DDIMSchedulerStepOutput(
            prev_sample=prev_sample,
            pred_original_sample=pred_original_sample
        )

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: Union[torch.Tensor, int],
        replace_noise=True
    ) -> torch.Tensor:
        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (original_samples.ndim - 1)))
        if replace_noise:
            # Hard-coded for 1000-step training schedules: at t == 999 the noised sample
            # is forced to be pure noise (alpha_bar set to 0).
            indices = (timesteps == 999).nonzero()
            if indices.numel() > 0:
                alpha_prod_t[indices] = 0
        # Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
        return alpha_prod_t ** (0.5) * original_samples + (1 - alpha_prod_t) ** (0.5) * noise

    def add_noise_lcm(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timestep: Union[torch.Tensor, int],
    ) -> torch.Tensor:
        # Mirrors the timestep lookup in `step`, then noises the sample to the *previous*
        # (less noisy) timestep of the inference schedule.
        if isinstance(timestep, int):
            idx = self.timesteps.index(timestep)
            prev_timestep = self.timesteps[idx + 1] if idx < self.num_inference_steps - 1 else None

            alpha_prod_t = self.alphas_cumprod[timestep]
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep is not None else self.final_alpha_cumprod
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev
        else:
            timesteps = torch.tensor(self.timesteps).to(timestep.device)
            idx = timestep.reshape(-1, 1).eq(timesteps.reshape(1, -1)).nonzero()[:, 1]
            prev_timestep = timesteps[idx.add(1).clamp_max(self.num_inference_steps - 1)]

            assert prev_timestep is not None

            alpha_prod_t = self.alphas_cumprod[timestep]
            alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
            alpha_prod_t_prev = torch.where(prev_timestep < 0, self.final_alpha_cumprod, alpha_prod_t_prev)
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev

            bs = timestep.size(0)
            alpha_prod_t = alpha_prod_t.view(bs, 1, 1, 1)
            alpha_prod_t_prev = alpha_prod_t_prev.view(bs, 1, 1, 1)
            beta_prod_t = beta_prod_t.view(bs, 1, 1, 1)
            beta_prod_t_prev = beta_prod_t_prev.view(bs, 1, 1, 1)

        alpha_prod_t_prev = alpha_prod_t_prev.reshape(-1, *([1] * (original_samples.ndim - 1)))
        return alpha_prod_t_prev ** (0.5) * original_samples + (1 - alpha_prod_t_prev) ** (0.5) * noise

    def convert_output(
        self,
        model_output: torch.Tensor,
        model_output_type: str,
        sample: torch.Tensor,
        timesteps: Union[torch.Tensor, int]
    ) -> DDIMSchedulerConversionOutput:
        assert model_output_type in self.prediction_types

        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
        beta_prod_t = 1 - alpha_prod_t

        if model_output_type == "epsilon":
            pred_epsilon = model_output
            pred_original_sample = (sample - beta_prod_t ** (0.5) * pred_epsilon) / alpha_prod_t ** (0.5)
            pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
        elif model_output_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
            pred_velocity = alpha_prod_t ** (0.5) * pred_epsilon - (1 - alpha_prod_t) ** (0.5) * pred_original_sample
        elif model_output_type == "v_prediction":
            pred_velocity = model_output
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError("Unknown prediction type")

        return DDIMSchedulerConversionOutput(
            pred_epsilon=pred_epsilon,
            pred_original_sample=pred_original_sample,
            pred_velocity=pred_velocity)

    def get_velocity(
        self,
        sample: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.FloatTensor:
        alpha_prod_t = self.alphas_cumprod[timesteps].reshape(-1, *([1] * (sample.ndim - 1)))
        return alpha_prod_t ** (0.5) * noise - (1 - alpha_prod_t) ** (0.5) * sample
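
# Minimal usage sketch (illustrative, not part of the original module): a linear beta
# schedule and a stand-in denoiser that predicts zero epsilon, just to show how
# `timesteps`, `step`, and the "epsilon" prediction type fit together.
if __name__ == "__main__":
    betas = torch.linspace(1e-4, 2e-2, 1000)
    scheduler = DDIMScheduler(num_train_timesteps=1000, num_inference_timesteps=4, betas=betas)
    sample = torch.randn(1, 4, 8, 8)
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)  # stand-in for a real denoising model
        sample = scheduler.step(model_output, "epsilon", t, sample).prev_sample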
|