import torch
import numpy as np

from safetensors import safe_open


def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    Create sinusoidal embeddings for a batch of guidance scales, following
    https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scales to embed.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate.
        dtype:
            data type of the generated embeddings.

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad to the requested dimension
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
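

# Usage sketch (illustrative only; the guidance scales and embedding_dim below are
# arbitrary assumptions, not values prescribed by this module): embed a batch of
# per-sample guidance scales, e.g. for conditioning a guidance-distilled UNet.
def _example_guidance_scale_embedding():
    w = torch.tensor([3.0, 5.0, 7.5])
    w_emb = guidance_scale_embedding(w, embedding_dim=256)
    assert w_emb.shape == (3, 256)
    return w_emb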


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]
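

# Usage sketch (illustrative; shapes are arbitrary assumptions): broadcast a
# per-sample scalar weight against an image-shaped batch.
def _example_append_dims():
    weights = torch.tensor([0.5, 2.0])  # shape (2,)
    images = torch.randn(2, 3, 8, 8)
    return append_dims(weights, images.ndim) * images  # weights become (2, 1, 1, 1)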


def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    # Consistency-model boundary-condition scalings: at timestep 0, c_skip == 1 and
    # c_out == 0, so the parameterization reduces to the identity on its input.
    scaled_timestep = timestep * timestep_scaling
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out
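

# Quick self-check (illustrative): at timestep 0 the consistency-model boundary
# condition should hold, i.e. c_skip -> 1 and c_out -> 0.
def _example_boundary_scalings():
    c_skip, c_out = scalings_for_boundary_conditions(torch.tensor([0.0]))
    assert torch.allclose(c_skip, torch.ones(1))
    assert torch.allclose(c_out, torch.zeros(1))
    return c_skip, c_out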


def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
    if prediction_type == "epsilon":
        # x_0 = (x_t - sigma_t * eps) / alpha_t
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        pred_x_0 = (sample - sigmas * model_output) / alphas
    elif prediction_type == "v_prediction":
        # x_0 = alpha_t * x_t - sigma_t * v
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(
            f"Prediction type {prediction_type} currently not supported.")

    return pred_x_0
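

# Usage sketch (illustrative; the linear beta schedule and tensor shapes are
# assumptions): recover a predicted x_0 from an epsilon-prediction model output,
# with alphas = sqrt(alpha_bar) and sigmas = sqrt(1 - alpha_bar).
def _example_predicted_origin():
    betas = torch.linspace(1e-4, 0.02, 1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    alphas = alphas_cumprod.sqrt()
    sigmas = (1.0 - alphas_cumprod).sqrt()
    sample = torch.randn(2, 4, 64, 64)
    noise_pred = torch.randn(2, 4, 64, 64)
    timesteps = torch.tensor([999, 500])
    return predicted_origin(noise_pred, timesteps, sample, "epsilon", alphas, sigmas)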


def scale_for_loss(timesteps, sample, prediction_type, alphas, sigmas):
    if prediction_type == "epsilon":
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        sample = sample * alphas / sigmas
    else:
        raise ValueError(
            f"Prediction type {prediction_type} currently not supported.")

    return sample


def extract_into_tensor(a, t, x_shape):
    """Gather values from `a` at indices `t` and reshape them to broadcast over `x_shape`."""
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
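

# Usage sketch (illustrative; the schedule and shapes are arbitrary assumptions):
# gather per-timestep schedule values and reshape them so they broadcast over a
# (batch, channels, height, width) sample.
def _example_extract_into_tensor():
    schedule = torch.linspace(1.0, 0.0, 1000)
    timesteps = torch.tensor([0, 999])
    values = extract_into_tensor(schedule, timesteps, (2, 4, 64, 64))
    assert values.shape == (2, 1, 1, 1)
    return values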


class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        # DDIM sampling timesteps: a uniform sub-grid of the training timesteps.
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (
            np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
        self.ddim_timesteps_prev = np.asarray(
            [0] + self.ddim_timesteps[:-1].tolist()
        )
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )

        # Convert to torch tensors.
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_timesteps_prev = torch.from_numpy(
            self.ddim_timesteps_prev).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(
            self.ddim_alpha_cumprods_prev)

    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(
            device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        # Deterministic (eta = 0) DDIM update to the previous DDIM timestep.
        alpha_cumprod_prev = extract_into_tensor(
            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev
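

# Usage sketch (illustrative; the linear beta schedule, shapes, and step index are
# assumptions): build a DDIM solver over 50 sub-sampled timesteps and take one
# deterministic step from the last DDIM timestep toward its predecessor.
def _example_ddim_solver():
    betas = torch.linspace(1e-4, 0.02, 1000)
    alpha_cumprods = torch.cumprod(1.0 - betas, dim=0).numpy()
    solver = DDIMSolver(alpha_cumprods, timesteps=1000, ddim_timesteps=50)
    pred_x0 = torch.randn(2, 4, 64, 64)
    pred_noise = torch.randn(2, 4, 64, 64)
    index = torch.tensor([49, 49])  # last entry of the 50-step DDIM grid
    return solver.ddim_step(pred_x0, pred_noise, index)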


@torch.no_grad()
def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
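

# Usage sketch (illustrative; the tiny linear layers stand in for real models):
# keep an EMA ("target") copy of a model's weights, updated after each training step.
def _example_update_ema():
    model = torch.nn.Linear(4, 4)
    target_model = torch.nn.Linear(4, 4)
    target_model.load_state_dict(model.state_dict())
    # ... an optimizer step on `model` would normally happen here ...
    update_ema(target_model.parameters(), model.parameters(), rate=0.95)
    return target_model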


def convert_lcm_lora(unet, path, alpha=1.0):
    # Load the LoRA state dict (either a pickled .ckpt or a safetensors file).
    if path.endswith(("ckpt",)):
        state_dict = torch.load(path, map_location="cpu")
    else:
        state_dict = {}
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)

    num_alpha = 0
    for key in state_dict.keys():
        if "alpha" in key:
            num_alpha += 1

    lora_keys = [k for k in state_dict.keys(
    ) if k.endswith("lora_down.weight")]

    # Rename kohya-style keys ("lora_unet_...") to diffusers-style module paths.
    updated_state_dict = {}
    for key in lora_keys:
        lora_name = key.split(".")[0]

        if lora_name.startswith("lora_unet_"):
            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")

            if "input.blocks" in diffusers_name:
                diffusers_name = diffusers_name.replace(
                    "input.blocks", "down_blocks")
            else:
                diffusers_name = diffusers_name.replace(
                    "down.blocks", "down_blocks")

            if "middle.block" in diffusers_name:
                diffusers_name = diffusers_name.replace(
                    "middle.block", "mid_block")
            else:
                diffusers_name = diffusers_name.replace(
                    "mid.block", "mid_block")
            if "output.blocks" in diffusers_name:
                diffusers_name = diffusers_name.replace(
                    "output.blocks", "up_blocks")
            else:
                diffusers_name = diffusers_name.replace(
                    "up.blocks", "up_blocks")

            diffusers_name = diffusers_name.replace(
                "transformer.blocks", "transformer_blocks")
            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
            diffusers_name = diffusers_name.replace(
                "to.out.0.lora", "to_out_lora")
            diffusers_name = diffusers_name.replace("proj.in", "proj_in")
            diffusers_name = diffusers_name.replace("proj.out", "proj_out")
            diffusers_name = diffusers_name.replace(
                "time.emb.proj", "time_emb_proj")
            diffusers_name = diffusers_name.replace(
                "conv.shortcut", "conv_shortcut")

            updated_state_dict[diffusers_name] = state_dict[key]
            up_diffusers_name = diffusers_name.replace(".down.", ".up.")
            up_key = key.replace("lora_down.weight", "lora_up.weight")
            updated_state_dict[up_diffusers_name] = state_dict[up_key]

    state_dict = updated_state_dict

    # Merge each LoRA pair (up @ down) directly into the corresponding UNet weight.
    num_lora = 0
    for key in state_dict:
        if "up." in key:
            continue
        up_key = key.replace(".down.", ".up.")
        model_key = key.replace("processor.", "").replace("_lora", "").replace(
            "down.", "").replace("up.", "").replace(".lora", "")
        model_key = model_key.replace("to_out.", "to_out.0.")
        layer_infos = model_key.split(".")[:-1]

        # Walk the module path to find the target layer.
        curr_layer = unet
        while len(layer_infos) > 0:
            temp_name = layer_infos.pop(0)
            curr_layer = curr_layer.__getattr__(temp_name)

        weight_down = state_dict[key].to(
            curr_layer.weight.data.device, curr_layer.weight.data.dtype)
        weight_up = state_dict[up_key].to(
            curr_layer.weight.data.device, curr_layer.weight.data.dtype)

        # The hard-coded 1/8 factor is the network_alpha / rank scaling assumed for
        # the expected LCM-LoRA checkpoint.
        if weight_up.ndim == 2:
            curr_layer.weight.data += 1 / 8 * alpha * \
                torch.mm(weight_up, weight_down)
        else:
            assert weight_up.ndim == 4
            curr_layer.weight.data += 1 / 8 * alpha * torch.mm(weight_up.flatten(
                start_dim=1), weight_down.flatten(start_dim=1)).reshape(curr_layer.weight.data.shape)
        num_lora += 1

    return unet
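

# Usage sketch (illustrative; the model id and LoRA filename are assumptions, and
# `diffusers` is an extra dependency this module does not otherwise require). The
# checkpoint must use the kohya-style "lora_unet_*" key naming this converter expects.
def _example_convert_lcm_lora():
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet")
    # Merges the LoRA weights into the UNet in place and returns the same module.
    return convert_lcm_lora(unet, "lcm-lora-sdv1-5.safetensors", alpha=1.0)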