# FRESCO / src/pipe_FRESCO.py
# SingleZombie
# upload files
# ff715ca
from src.utils import *
from src.flow_utils import warp_tensor
import torch
import torchvision
import gc
"""
==========================================================================
* step(): one DDPM step with background smoothing
* inference(): translate one batch with FRESCO and background smoothing
==========================================================================
"""
def step(pipe, model_output, timestep, sample, generator, repeat_noise=False,
         visualize_pipeline=False, flows=None, occs=None, saliency=None):
    """
    Run one DDPM update x_t -> x_{t-1}, optionally with background smoothing.

    The update is the DDPM posterior mean (formula (7) of
    https://arxiv.org/pdf/2006.11239.pdf) plus Gaussian noise with variance
    (1 - abar_{t-1}) / (1 - abar_t) * beta_t, clamped to be >= 1e-20.

    Background smoothing runs only when saliency, flows and occs are all
    given. The predicted x_0 is decoded with the VAE and passed through
    warp_tensor, which warps the previous frame and fuses it with the current
    one in the non-salient (background) region. The result is then
    re-encoded to latents.

    Args:
        pipe: pipeline object. Reads pipe.scheduler (previous_timestep,
            alphas_cumprod, one). Reads pipe.vae only for smoothing or
            visualization.
        model_output: predicted noise for `sample`, batch on dim 0.
        timestep: current timestep t.
        sample: current noisy latents x_t.
        generator: torch.Generator used to draw this step's noise.
        repeat_noise: if True, frame 0's noise is reused for every frame
            (via repeat(B, 1, 1, 1), so 4-D latents are assumed).
        visualize_pipeline: debug flag. Decodes x_0 and shows it via visualize().
        flows, occs, saliency: forwarded to warp_tensor for smoothing.

    Returns:
        (pred_prev_sample, pred_original_sample): x_{t-1} and the predicted x_0.
    """
    sched = pipe.scheduler

    # Cumulative alpha products at t and at the previous timestep. Past the
    # end of the schedule (previous timestep < 0) the scheduler's `one` is
    # used (presumably 1.0, which makes the step variance clamp to 1e-20).
    t_prev = sched.previous_timestep(timestep)
    abar_t = sched.alphas_cumprod[timestep]
    if t_prev >= 0:
        abar_prev = sched.alphas_cumprod[t_prev]
    else:
        abar_prev = sched.one
    one_minus_abar_t = 1 - abar_t
    one_minus_abar_prev = 1 - abar_prev
    alpha_t = abar_t / abar_prev
    beta_t = 1 - alpha_t

    # Predicted x_0 from the predicted noise, formula (12) of
    # https://arxiv.org/pdf/2010.02502.pdf
    x0_pred = (sample - one_minus_abar_t ** (0.5) * model_output) / abar_t ** (0.5)

    # [HACK] background smoothing: decode x_0, warp the previous frame into
    # the background of the current one, then encode back to latents.
    if all(arg is not None for arg in (saliency, flows, occs)):
        scale = pipe.vae.config.scaling_factor
        decoded = pipe.vae.decode(x0_pred / scale).sample
        decoded = warp_tensor(decoded, flows, occs, saliency, unet_chunk_size=1)
        x0_pred = scale * pipe.vae.encode(decoded).latent_dist.sample()

    # Posterior-mean coefficients for x_0 and x_t, formula (7) of
    # https://arxiv.org/pdf/2006.11239.pdf
    x0_coeff = (abar_prev ** (0.5) * beta_t) / one_minus_abar_t
    xt_coeff = alpha_t ** (0.5) * one_minus_abar_prev / one_minus_abar_t
    prev_mean = x0_coeff * x0_pred + xt_coeff * sample

    # Posterior variance, floored at 1e-20, then scaled Gaussian noise.
    sigma_sq = torch.clamp(one_minus_abar_prev / one_minus_abar_t * beta_t, min=1e-20)
    noise = torch.randn(model_output.shape, generator=generator,
                        device=model_output.device, dtype=model_output.dtype)
    scaled_noise = (sigma_sq ** 0.5) * noise

    # [HACK] background smoothing: one shared noise for all frames can help
    # keep a static background stable.
    if repeat_noise:
        scaled_noise = scaled_noise[0:1].repeat(model_output.shape[0], 1, 1, 1)

    if visualize_pipeline:  # debug only
        preview = pipe.vae.decode(x0_pred / pipe.vae.config.scaling_factor).sample
        grid = torchvision.utils.make_grid(torch.clamp(preview, -1, 1), preview.shape[0], 1)
        visualize(grid.cpu(), 90)

    return (prev_mean + scaled_noise, x0_pred)
@torch.no_grad()
def inference(pipe, controlnet, frescoProc,
              imgs, prompt_embeds, edges, timesteps,
              cond_scale=(0.7,) * 20, num_inference_steps=20, num_warmup_steps=6,
              do_classifier_free_guidance=True, seed=0, guidance_scale=7.5, use_controlnet=True,
              record_latents=None, propagation_mode=False, visualize_pipeline=False,
              flows=None, occs=None, saliency=None, repeat_noise=False,
              num_intraattn_steps=1, step_interattn_end=350, bg_smoothing_steps=(16, 17)):
    """
    video-to-video translation inference pipeline with FRESCO
    * add controlnet and SDEdit
    * add FRESCO-guided attention
    * add FRESCO-guided optimization
    * add background smoothing
    * add support for inter-batch long video translation
    [input of the original pipe]
    pipe: base diffusion model
    imgs: a batch of the input frames (unpacked as B, C, H, W)
    prompt_embeds: prompts
    num_inference_steps: number of DDPM steps
    timesteps: generated by pipe.scheduler.set_timesteps(num_inference_steps)
    do_classifier_free_guidance: cfg, should be always true
    guidance_scale: cfg scale
    seed: seed of the torch.Generator used for the initial noise and every DDPM step
    [input of SDEdit]
    num_warmup_steps: skip the first num_warmup_steps DDPM steps
        (a negative value disables SDEdit and starts from pure noise)
    [input of controlnet]
    use_controlnet: bool, whether using controlnet
    controlnet: controlnet model
    edges: input for controlnet (edge/stroke/depth, etc.)
    cond_scale: controlnet scale per DDPM step, indexed by the absolute step
        index (needs at least len(timesteps) entries)
    [input of FRESCO]
    frescoProc: FRESCO attention controller
    flows: optical flows
    occs: occlusion mask
    num_intraattn_steps: apply num_intraattn_steps steps of spatial-guided attention
    step_interattn_end: apply temporal-guided attention while t >= step_interattn_end
        (i.e., in [step_interattn_end, 1000])
    [input for background smoothing]
    saliency: saliency mask
    repeat_noise: bool, use the same initial noise for all frames
        NOTE(review): it is not forwarded to step(), so the per-step DDPM
        noise is still drawn independently per frame
    bg_smoothing_steps: absolute DDPM step indices at which background smoothing is applied
    [input for long video translation]
    record_latents: list of recorded latents, one entry per denoising step, each holding
        the latents of frames [0, last]. It is filled in place when propagation_mode is
        False and read and updated in place when propagation_mode is True, so pass your
        own list to keep it across batches. Defaults to a fresh empty list per call.
    propagation_mode: bool, whether this is not the first batch
    [output]
    latents: a batch of latents of the translated frames
    """
    # A fresh list per call. The former mutable default `[]` was shared by all
    # calls and grew in place (`+=`), keeping recorded tensors alive forever.
    if record_latents is None:
        record_latents = []

    gc.collect()
    torch.cuda.empty_cache()
    device = pipe._execution_device
    noise_scheduler = pipe.scheduler
    generator = torch.Generator(device=device).manual_seed(seed)

    B, C, H, W = imgs.shape
    latents = pipe.prepare_latents(
        B,
        pipe.unet.config.in_channels,
        H,
        W,
        prompt_embeds.dtype,
        device,
        generator,
        latents=None,
    )
    if repeat_noise:
        # share the first frame's initial noise across the whole batch
        latents = latents[0:1].repeat(B, 1, 1, 1).detach()

    if num_warmup_steps < 0:
        # no SDEdit: denoise the sampled noise over the full schedule
        latents_init = latents.detach()
        num_warmup_steps = 0
    else:
        # SDEdit: use the noisy latents of the images as the input rather than pure Gaussian noise
        latent_x0 = pipe.vae.config.scaling_factor * pipe.vae.encode(imgs.to(pipe.unet.dtype)).latent_dist.sample()
        latents_init = noise_scheduler.add_noise(latent_x0, latents, timesteps[num_warmup_steps]).detach()

    # SDEdit: run only the timesteps after the first num_warmup_steps
    denoise_timesteps = timesteps[num_warmup_steps:]
    # Index of the final loop iteration. The old check compared against
    # len(timesteps) - 1, which is unreachable when num_warmup_steps > 0.
    last_iter = len(denoise_timesteps) - 1

    with pipe.progress_bar(total=num_inference_steps - num_warmup_steps) as progress_bar:
        latents = latents_init
        for i, t in enumerate(denoise_timesteps):
            step_idx = i + num_warmup_steps  # index into the full schedule

            # [HACK] control the steps to apply spatial/temporal-guided attention.
            # Once disabled, neither is re-enabled within this call.
            if i >= num_intraattn_steps:
                frescoProc.controller.disable_intraattn()
            if t < step_interattn_end:
                frescoProc.controller.disable_interattn()

            # [HACK] record and restore latents from the previous batch
            if propagation_mode:
                # overwrite frames 0 and 1 with the previous batch's record for this
                # step, then record frames [0, last] of the current batch
                latents[0:2] = record_latents[i].detach().clone()
                record_latents[i] = latents[[0, len(latents) - 1]].detach().clone()
            else:
                # first batch: record_latents[i] = latents of frames [0, last] at step i
                record_latents += [latents[[0, len(latents) - 1]].detach().clone()]

            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            if use_controlnet:
                down_block_res_samples, mid_block_res_sample = controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    controlnet_cond=edges,
                    conditioning_scale=cond_scale[step_idx],
                    guess_mode=False,
                    return_dict=False,
                )
            else:
                down_block_res_samples, mid_block_res_sample = None, None

            # predict the noise residual
            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            # [HACK] background smoothing. Note: bg_smoothing_steps should be
            # rescaled based on num_inference_steps; the default (16, 17) assumes
            # num_inference_steps=20.
            if step_idx in bg_smoothing_steps:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline,
                               flows=flows, occs=occs, saliency=saliency)[0]
            else:
                latents = step(pipe, noise_pred, t, latents, generator,
                               visualize_pipeline=visualize_pipeline)[0]

            # advance the progress bar on the last iteration or every scheduler.order steps
            if i == last_iter or (i + 1) % noise_scheduler.order == 0:
                progress_bar.update()

    return latents