from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import os
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import argparse
from torchvision.io import write_video
from pathlib import Path
from PIL import Image
from util import *
import torchvision.transforms as T


def get_timesteps(scheduler, num_inference_steps, strength, device):
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start:]
    return timesteps, num_inference_steps - t_start


def decode_latents(pipe, latents):
    decoded = []
    batch_size = 8
    for b in range(0, latents.shape[0], batch_size):
        # rescale from the VAE latent space before decoding
        latents_batch = 1 / 0.18215 * latents[b:b + batch_size]
        imgs = pipe.vae.decode(latents_batch).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        decoded.append(imgs)
    return torch.cat(decoded)


def ddim_inversion(pipe, cond, latent_frames, batch_size, save_latents=True, timesteps_to_save=None):
    # walk the scheduler timesteps in reverse (clean -> noisy) to invert the frames
    timesteps = reversed(pipe.scheduler.timesteps)
    timesteps_to_save = timesteps_to_save if timesteps_to_save is not None else timesteps
    for i, t in enumerate(tqdm(timesteps)):
        for b in range(0, latent_frames.shape[0], batch_size):
            x_batch = latent_frames[b:b + batch_size]
            model_input = x_batch
            cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
            # uncomment the block below to support a depth-conditioned model
            # if self.sd_version == 'depth':
            #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
            #     model_input = torch.cat([x_batch, depth_maps], dim=1)

            alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                pipe.scheduler.alphas_cumprod[timesteps[i - 1]]
                if i > 0 else pipe.scheduler.final_alpha_cumprod
            )
            mu = alpha_prod_t ** 0.5
            mu_prev = alpha_prod_t_prev ** 0.5
            sigma = (1 - alpha_prod_t) ** 0.5
            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

            # replace the line below with the commented block to support ControlNet
            eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # if self.sd_version != 'ControlNet':
            #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # else:
            #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))

            pred_x0 = (x_batch - sigma_prev * eps) / mu_prev
            latent_frames[b:b + batch_size] = mu * pred_x0 + sigma * eps
        # if save_latents and t in timesteps_to_save:
        #     torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
    return latent_frames
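

# For reference: the update above is the standard DDIM inversion step. At
# iteration i, x_batch holds latents at the previous (less noisy) timestep; the
# UNet predicts eps at timestep t, the clean latent is estimated as
#     pred_x0 = (x - sigma_prev * eps) / mu_prev
# with mu = sqrt(alphas_cumprod) and sigma = sqrt(1 - alphas_cumprod), and the
# batch is then re-noised one level up:
#     x_t = mu * pred_x0 + sigma * eps
# This relies on the usual approximation that eps changes little between
# adjacent timesteps, so the trajectory can be retraced by ddim_sample below.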


def ddim_sample(pipe, x, cond, batch_size):
    # standard DDIM sampling (noisy -> clean); mirrors ddim_inversion above
    timesteps = pipe.scheduler.timesteps
    for i, t in enumerate(tqdm(timesteps)):
        for b in range(0, x.shape[0], batch_size):
            x_batch = x[b:b + batch_size]
            model_input = x_batch
            cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
            # uncomment the block below to support a depth-conditioned model
            # if self.sd_version == 'depth':
            #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
            #     model_input = torch.cat([x_batch, depth_maps], dim=1)

            alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                pipe.scheduler.alphas_cumprod[timesteps[i + 1]]
                if i < len(timesteps) - 1
                else pipe.scheduler.final_alpha_cumprod
            )
            mu = alpha_prod_t ** 0.5
            sigma = (1 - alpha_prod_t) ** 0.5
            mu_prev = alpha_prod_t_prev ** 0.5
            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

            # replace the line below with the commented block to support ControlNet
            eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # if self.sd_version != 'ControlNet':
            #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # else:
            #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))

            pred_x0 = (x_batch - sigma * eps) / mu
            x[b:b + batch_size] = mu_prev * pred_x0 + sigma_prev * eps
    return x


def get_text_embeds(pipe, prompt, negative_prompt, batch_size=1, device="cuda"):
    # Tokenize text and get embeddings
    text_input = pipe.tokenizer(prompt, padding='max_length', max_length=pipe.tokenizer.model_max_length,
                                truncation=True, return_tensors='pt')
    text_embeddings = pipe.text_encoder(text_input.input_ids.to(pipe.device))[0]

    # Do the same for unconditional embeddings
    uncond_input = pipe.tokenizer(negative_prompt, padding='max_length', max_length=pipe.tokenizer.model_max_length,
                                  return_tensors='pt')
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

    # Concatenate for final embeddings: [uncond] * batch_size followed by [cond] * batch_size
    text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size)
    return text_embeddings
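

# Minimal usage sketch (assumes `pipe` is a loaded Stable Diffusion pipeline):
# the first half of the returned tensor is the negative/unconditional embedding
# and the second half the conditional one, the layout expected for
# classifier-free guidance; extract_latents below keeps only the conditional row.
#
#     embeds = get_text_embeds(pipe, "a photo of a cat", "", batch_size=1)
#     uncond, cond = embeds.chunk(2)  # each of shape [1, 77, hidden_dim] for SD 1.x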


def extract_latents(pipe,
                    num_steps,
                    latent_frames,
                    batch_size,
                    timesteps_to_save,
                    inversion_prompt=''):
    pipe.scheduler.set_timesteps(num_steps)
    # keep only the conditional embedding (index 1) from the [uncond, cond] pair
    cond = get_text_embeds(pipe, inversion_prompt, "", device=pipe.device)[1].unsqueeze(0)
    # latent_frames = self.latents
    inverted_latents = ddim_inversion(pipe, cond,
                                      latent_frames,
                                      batch_size=batch_size,
                                      save_latents=False,
                                      timesteps_to_save=timesteps_to_save)
    # sanity check: reconstruct the frames from the inverted latents
    # latent_reconstruction = ddim_sample(pipe, inverted_latents, cond, batch_size=batch_size)
    # rgb_reconstruction = decode_latents(pipe, latent_reconstruction)
    # return rgb_reconstruction
    return inverted_latents


def encode_imgs(pipe, imgs, batch_size=10, deterministic=True):
    # map images from [0, 1] to [-1, 1] as expected by the VAE
    imgs = 2 * imgs - 1
    latents = []
    for i in range(0, len(imgs), batch_size):
        posterior = pipe.vae.encode(imgs[i:i + batch_size]).latent_dist
        latent = posterior.mean if deterministic else posterior.sample()
        latents.append(latent * 0.18215)
    latents = torch.cat(latents)
    return latents


def get_data(pipe, frames, n_frames):
    """
    Converts frames to tensors, moves them to the pipeline's device and encodes them to obtain latents.
    """
    frames = frames[:n_frames]
    if frames[0].size[0] == frames[0].size[1]:
        frames = [frame.convert("RGB").resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
    stacked_tensor_frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(pipe.device)
    # encode to latents
    latents = encode_imgs(pipe, stacked_tensor_frames, deterministic=True).to(torch.float16).to(pipe.device)
    return stacked_tensor_frames, latents
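

# Illustrative sketch of how these helpers can be wired together. It assumes a
# CUDA device and a fp16 Stable Diffusion checkpoint; the checkpoint id, frame
# directory and hyper-parameters below are placeholder assumptions.
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    # DDIM inversion requires a DDIM scheduler
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # load frames as PIL images from a (placeholder) directory of extracted frames
    frame_paths = sorted(Path("data/frames").glob("*.png"))
    frames = [Image.open(p) for p in frame_paths]

    with torch.no_grad():
        tensor_frames, latents = get_data(pipe, frames, n_frames=40)
        inverted = extract_latents(pipe,
                                   num_steps=50,
                                   latent_frames=latents,
                                   batch_size=8,
                                   timesteps_to_save=None,
                                   inversion_prompt="")
    print(inverted.shape)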