tokenflow / preprocess_utils.py
Linoy Tsaban
Create preprocess_utils.py
1a2c8b5
raw
history blame
7.72 kB
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
# suppress partial model loading warning
logging.set_verbosity_error()
import os
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import argparse
from torchvision.io import write_video
from pathlib import Path
from util import *
import torchvision.transforms as T
def get_timesteps(scheduler, num_inference_steps, strength, device):
    """Return the tail of ``scheduler.timesteps`` selected by ``strength``.

    ``int(num_inference_steps * strength)`` steps (capped at
    ``num_inference_steps``) are kept, counted from the end of the schedule.

    Args:
        scheduler: Object exposing a sliceable ``timesteps`` attribute.
        num_inference_steps: Total number of steps in the schedule.
        strength: Fraction of the schedule to keep.
        device: Not used by this function; kept for interface compatibility.

    Returns:
        A ``(timesteps, count)`` tuple: the sliced timesteps and
        ``num_inference_steps`` minus the index at which the slice starts.
    """
    kept_steps = min(int(num_inference_steps * strength), num_inference_steps)
    # Index of the first timestep that is kept; never negative.
    start_index = max(num_inference_steps - kept_steps, 0)
    remaining = scheduler.timesteps[start_index:]
    return remaining, num_inference_steps - start_index
@torch.no_grad()
def decode_latents(pipe, latents, batch_size=8):
    """Decode latents into images with ``pipe.vae``, one chunk at a time.

    Each chunk of ``latents`` is multiplied by ``1 / 0.18215`` (undoing the
    factor that ``encode_imgs`` applies), passed to ``pipe.vae.decode`` and
    the decoded ``.sample`` is mapped with ``x / 2 + 0.5`` and clamped to
    ``[0, 1]``.

    Args:
        pipe: Object exposing ``vae.decode(...)`` whose result has a
            ``.sample`` attribute.
        latents: Tensor of latents, chunked along dimension 0.
        batch_size: Number of latents decoded per ``vae.decode`` call.
            Defaults to 8, the value that used to be hard-coded.

    Returns:
        The decoded chunks concatenated along dimension 0.
    """
    inv_scale = 1 / 0.18215
    decoded = []
    for start in range(0, latents.shape[0], batch_size):
        latents_batch = inv_scale * latents[start:start + batch_size]
        imgs = pipe.vae.decode(latents_batch).sample
        # Shift/scale the decoder output and clip it to the [0, 1] range.
        decoded.append((imgs / 2 + 0.5).clamp(0, 1))
    return torch.cat(decoded)
@torch.no_grad()
def ddim_inversion(pipe, cond, latent_frames, batch_size, save_latents=True, timesteps_to_save=None):
    """Run DDIM inversion over ``latent_frames``, updating them in place.

    The scheduler's timesteps are visited in reversed order. At every
    timestep the UNet noise prediction is used to move each batch of
    latents from the previous noise level to the current one, and the
    result is written back into ``latent_frames``.

    Args:
        pipe: Object exposing ``unet`` and ``scheduler``; the scheduler must
            provide ``timesteps``, ``alphas_cumprod`` and
            ``final_alpha_cumprod``.
        cond: Text conditioning, repeated once per frame of each batch via
            ``cond.repeat(n, 1, 1)`` (assumes a leading dimension of 1 --
            TODO confirm against caller).
        latent_frames: Tensor of latents, batched along dimension 0.
            It is overwritten in place.
        batch_size: Number of frames passed to the UNet per call.
        save_latents: Currently unused; the saving code is commented out.
        timesteps_to_save: Currently unused; the saving code is commented out.

    Returns:
        ``latent_frames`` (the same object that was passed in) after the
        final inversion step.
    """
    scheduler = pipe.scheduler
    # NOTE(review): relies on reversed() returning an indexable sequence,
    # which is the case when scheduler.timesteps is a torch tensor.
    timesteps = reversed(scheduler.timesteps)

    for i, t in enumerate(tqdm(timesteps)):
        # These scalars depend only on the timestep, so compute them once
        # per step instead of once per batch.
        alpha_prod_t = scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = (
            scheduler.alphas_cumprod[timesteps[i - 1]]
            if i > 0 else scheduler.final_alpha_cumprod
        )
        mu = alpha_prod_t ** 0.5
        mu_prev = alpha_prod_t_prev ** 0.5
        sigma = (1 - alpha_prod_t) ** 0.5
        sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

        for b in range(0, latent_frames.shape[0], batch_size):
            x_batch = latent_frames[b:b + batch_size]
            model_input = x_batch
            cond_batch = cond.repeat(x_batch.shape[0], 1, 1)

            # remove comment from commented block to support controlnet
            # if self.sd_version == 'depth':
            #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
            #     model_input = torch.cat([x_batch, depth_maps],dim=1)

            # remove line below and replace with commented block to support controlnet
            eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # if self.sd_version != 'ControlNet':
            #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # else:
            #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))

            # Estimate the clean latent from the previous noise level, then
            # re-noise it to the current (higher) noise level.
            pred_x0 = (x_batch - sigma_prev * eps) / mu_prev
            latent_frames[b:b + batch_size] = mu * pred_x0 + sigma * eps

    # Saving of intermediate latents is disabled:
    # if save_latents and t in timesteps_to_save:
    #     torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
    # torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
    return latent_frames
@torch.no_grad()
def ddim_sample(pipe, x, cond, batch_size):
    """Run DDIM sampling over ``x``, updating it in place.

    The scheduler's timesteps are visited in their stored order. At every
    timestep the UNet noise prediction is used to move each batch of
    latents from the current noise level to the next one in the schedule
    (``final_alpha_cumprod`` is used after the last timestep), and the
    result is written back into ``x``.

    Args:
        pipe: Object exposing ``unet`` and ``scheduler``; the scheduler must
            provide ``timesteps``, ``alphas_cumprod`` and
            ``final_alpha_cumprod``.
        x: Tensor of latents, batched along dimension 0. It is overwritten
            in place.
        cond: Text conditioning, repeated once per frame of each batch via
            ``cond.repeat(n, 1, 1)`` (assumes a leading dimension of 1 --
            TODO confirm against caller).
        batch_size: Number of frames passed to the UNet per call.

    Returns:
        ``x`` (the same object that was passed in) after the final step.
    """
    scheduler = pipe.scheduler
    timesteps = scheduler.timesteps
    last_index = len(timesteps) - 1

    for i, t in enumerate(tqdm(timesteps)):
        # These scalars depend only on the timestep, so compute them once
        # per step instead of once per batch.
        alpha_prod_t = scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = (
            scheduler.alphas_cumprod[timesteps[i + 1]]
            if i < last_index
            else scheduler.final_alpha_cumprod
        )
        mu = alpha_prod_t ** 0.5
        sigma = (1 - alpha_prod_t) ** 0.5
        mu_prev = alpha_prod_t_prev ** 0.5
        sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

        for b in range(0, x.shape[0], batch_size):
            x_batch = x[b:b + batch_size]
            model_input = x_batch
            cond_batch = cond.repeat(x_batch.shape[0], 1, 1)

            # remove comment from commented block to support controlnet
            # if self.sd_version == 'depth':
            #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
            #     model_input = torch.cat([x_batch, depth_maps],dim=1)

            # remove line below and replace with commented block to support controlnet
            eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # if self.sd_version != 'ControlNet':
            #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # else:
            #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))

            # Estimate the clean latent at the current noise level, then
            # re-noise it to the next noise level in the schedule.
            pred_x0 = (x_batch - sigma * eps) / mu
            x[b:b + batch_size] = mu_prev * pred_x0 + sigma_prev * eps

    return x
@torch.no_grad()
def get_text_embeds(pipe, prompt, negative_prompt, batch_size=1, device="cuda"):
    """Encode a prompt and a negative prompt with the pipeline's text encoder.

    Both texts are tokenized with identical settings (padded and truncated
    to ``pipe.tokenizer.model_max_length``) so that their embeddings always
    have the same sequence length and can be concatenated. Previously the
    negative prompt was not truncated, so an over-long negative prompt
    produced a longer sequence than the prompt.

    Args:
        pipe: Object exposing ``tokenizer``, ``text_encoder`` and ``device``.
        prompt: Text for the conditional embedding.
        negative_prompt: Text for the unconditional embedding.
        batch_size: Number of times each embedding is repeated in the output.
        device: Not used by this function (token ids are moved to
            ``pipe.device``); kept for interface compatibility.

    Returns:
        A tensor holding ``batch_size`` copies of the negative-prompt
        embedding followed by ``batch_size`` copies of the prompt embedding,
        concatenated along dimension 0.
    """
    tokenizer = pipe.tokenizer
    tokenizer_kwargs = dict(
        padding='max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors='pt',
    )

    # Tokenize text and get embeddings
    text_input = tokenizer(prompt, **tokenizer_kwargs)
    text_embeddings = pipe.text_encoder(text_input.input_ids.to(pipe.device))[0]

    # Do the same for unconditional embeddings
    uncond_input = tokenizer(negative_prompt, **tokenizer_kwargs)
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

    # Cat for final embeddings: unconditional first, conditional second
    return torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size)
@torch.no_grad()
def extract_latents(pipe,
                    num_steps,
                    latent_frames,
                    batch_size,
                    timesteps_to_save,
                    inversion_prompt=''):
    """Invert ``latent_frames`` with DDIM, conditioned on ``inversion_prompt``.

    Configures the scheduler for ``num_steps`` steps, embeds the inversion
    prompt (with an empty negative prompt) and hands everything to
    ``ddim_inversion`` with latent saving turned off.

    Args:
        pipe: Pipeline-like object passed through to ``get_text_embeds`` and
            ``ddim_inversion``; its ``scheduler.set_timesteps`` is called.
        num_steps: Number of scheduler steps to configure.
        latent_frames: Latents to invert; forwarded to ``ddim_inversion``.
        batch_size: Batch size forwarded to ``ddim_inversion``.
        timesteps_to_save: Forwarded to ``ddim_inversion``.
        inversion_prompt: Text used for the conditioning embedding.

    Returns:
        Whatever ``ddim_inversion`` returns for the given latents.
    """
    pipe.scheduler.set_timesteps(num_steps)

    embeddings = get_text_embeds(pipe, inversion_prompt, "", device=pipe.device)
    # Entry 0 is the (empty) negative-prompt embedding and entry 1 the prompt
    # embedding; keep the latter and restore its leading dimension.
    cond = embeddings[1].unsqueeze(0)

    # Optional reconstruction path, currently disabled:
    # latent_reconstruction = ddim_sample(pipe, inverted_latents, cond, batch_size=batch_size)
    # rgb_reconstruction = decode_latents(pipe, latent_reconstruction)
    # return rgb_reconstruction
    return ddim_inversion(
        pipe,
        cond,
        latent_frames,
        batch_size=batch_size,
        save_latents=False,
        timesteps_to_save=timesteps_to_save,
    )
@torch.no_grad()
def encode_imgs(pipe, imgs, batch_size=10, deterministic=True):
    """Encode images into latents with ``pipe.vae``, one chunk at a time.

    The images are first mapped with ``2 * imgs - 1``; each chunk is then
    passed to ``pipe.vae.encode`` and the resulting latent is multiplied by
    0.18215.

    Args:
        pipe: Object exposing ``vae.encode(...)`` whose result has a
            ``latent_dist`` attribute with ``mean`` and ``sample()``.
        imgs: Images to encode, chunked along the first dimension.
        batch_size: Number of images encoded per ``vae.encode`` call.
        deterministic: If true use ``latent_dist.mean``; otherwise draw
            from ``latent_dist.sample()``.

    Returns:
        The scaled latents of all chunks concatenated along dimension 0.
    """
    imgs = 2 * imgs - 1
    encoded = []
    for start in range(0, len(imgs), batch_size):
        dist = pipe.vae.encode(imgs[start:start + batch_size]).latent_dist
        if deterministic:
            latent = dist.mean
        else:
            latent = dist.sample()
        encoded.append(latent * 0.18215)
    return torch.cat(encoded)
def get_data(pipe, frames, n_frames):
    """Turn the first ``n_frames`` frames into tensors and VAE latents.

    If the first frame is square, every selected frame is converted to RGB
    and resized to 512x512 with Lanczos resampling. The frames are then
    converted to tensors, stacked, cast to float16, moved to ``pipe.device``
    and encoded with ``encode_imgs``.

    Args:
        pipe: Pipeline-like object providing ``device``; also passed to
            ``encode_imgs``.
        frames: Sequence of image frames (presumably PIL images, given the
            ``size``/``convert``/``resize`` calls -- verify against caller).
        n_frames: Maximum number of frames to use.

    Returns:
        A ``(stacked_tensor_frames, latents)`` tuple, both float16 on
        ``pipe.device``.
    """
    selected = frames[:n_frames]

    first_size = selected[0].size
    # NOTE(review): frames whose first frame is not square are neither
    # converted to RGB nor resized -- confirm this is intended.
    if first_size[0] == first_size[1]:
        selected = [
            frame.convert("RGB").resize((512, 512), resample=Image.Resampling.LANCZOS)
            for frame in selected
        ]

    to_tensor = T.ToTensor()
    frame_tensors = [to_tensor(frame) for frame in selected]
    stacked_tensor_frames = torch.stack(frame_tensors).to(torch.float16).to(pipe.device)

    # encode to latents
    latents = encode_imgs(pipe, stacked_tensor_frames, deterministic=True)
    latents = latents.to(torch.float16).to(pipe.device)
    return stacked_tensor_frames, latents