# chronodepth/chronodepth_pipeline.py
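"""ChronoDepth pipeline: video depth estimation built on Stable Video Diffusion."""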
import inspect
from typing import Union, Optional, List
import torch
import numpy as np
from tqdm.auto import tqdm
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
_resize_with_antialiasing,
StableVideoDiffusionPipelineOutput,
StableVideoDiffusionPipeline,
)
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
from einops import rearrange
class ChronoDepthPipeline(StableVideoDiffusionPipeline):
@torch.inference_mode()
def encode_images(self,
images: torch.Tensor,
decode_chunk_size=5,
):
video_length = images.shape[1]
images = rearrange(images, "b f c h w -> (b f) c h w")
latents = []
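        # Encode the frames through the VAE in chunks to bound peak memory.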
for i in range(0, images.shape[0], decode_chunk_size):
latents_chunk = self.vae.encode(images[i : i + decode_chunk_size]).latent_dist.sample()
latents.append(latents_chunk)
latents = torch.cat(latents, dim=0)
latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
latents = latents * self.vae.config.scaling_factor
return latents
@torch.inference_mode()
def _encode_image(self, images, device, discard=True, chunk_size=14):
        '''
        If `discard` is True, the input images are replaced with zero tensors,
        which effectively discards the CLIP image embeddings.
        '''
dtype = next(self.image_encoder.parameters()).dtype
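        # CLIP preprocessing: resize with antialiasing to 224x224 and map
        # pixel values from [-1, 1] back to [0, 1].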
images = _resize_with_antialiasing(images, (224, 224))
images = (images + 1.0) / 2.0
if discard:
images = torch.zeros_like(images)
image_embeddings = []
for i in range(0, images.shape[0], chunk_size):
tmp = self.feature_extractor(
images=images[i : i + chunk_size],
do_normalize=True,
do_center_crop=False,
do_resize=False,
do_rescale=False,
return_tensors="pt",
).pixel_values
tmp = tmp.to(device=device, dtype=dtype)
image_embeddings.append(self.image_encoder(tmp).image_embeds)
image_embeddings = torch.cat(image_embeddings, dim=0)
image_embeddings = image_embeddings.unsqueeze(1) # [t, 1, 1024]
return image_embeddings
def decode_depth(self, depth_latent: torch.Tensor, decode_chunk_size=5) -> torch.Tensor:
num_frames = depth_latent.shape[1]
depth_latent = rearrange(depth_latent, "b f c h w -> (b f) c h w")
depth_latent = depth_latent / self.vae.config.scaling_factor
forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
depth_frames = []
for i in range(0, depth_latent.shape[0], decode_chunk_size):
num_frames_in = depth_latent[i : i + decode_chunk_size].shape[0]
decode_kwargs = {}
if accepts_num_frames:
# we only pass num_frames_in if it's expected
decode_kwargs["num_frames"] = num_frames_in
depth_frame = self.vae.decode(depth_latent[i : i + decode_chunk_size], **decode_kwargs).sample
depth_frames.append(depth_frame)
depth_frames = torch.cat(depth_frames, dim=0)
depth_frames = depth_frames.reshape(-1, num_frames, *depth_frames.shape[1:])
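        # The VAE decodes 3-channel outputs; average the channels to obtain a
        # single-channel depth map.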
depth_mean = depth_frames.mean(dim=2, keepdim=True)
return depth_mean
@staticmethod
def check_inputs(images, height, width):
if (
not isinstance(images, torch.Tensor)
and not isinstance(images, np.ndarray)
):
raise ValueError(
"`images` has to be of type `torch.Tensor` or `numpy.ndarray` but is"
f" {type(images)}"
)
        if height % 64 != 0 or width % 64 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
@torch.no_grad()
def __call__(
self,
input_images: Union[np.ndarray, torch.FloatTensor],
height: int = 576,
width: int = 768,
num_inference_steps: int = 10,
fps: int = 7,
motion_bucket_id: int = 127,
noise_aug_strength: float = 0.02,
decode_chunk_size: Optional[int] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
show_progress_bar: bool = True,
latents: Optional[torch.Tensor] = None,
infer_mode: str = 'ours',
sigma_epsilon: float = -4,
):
"""
Args:
input_images: shape [T, H, W, 3] if np.ndarray or [T, 3, H, W] if torch.FloatTensor, range [0, 1]
height: int, height of the input image
width: int, width of the input image
num_inference_steps: int, number of inference steps
fps: int, frames per second
motion_bucket_id: int, motion bucket id
noise_aug_strength: float, noise augmentation strength
decode_chunk_size: int, decode chunk size
generator: torch.Generator or List[torch.Generator], random number generator
show_progress_bar: bool, show progress bar
"""
        assert height >= 0 and width >= 0
        assert num_inference_steps >= 1
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
# 1. Check inputs. Raise error if not correct
self.check_inputs(input_images, height, width)
# 2. Define call parameters
        batch_size = 1  # only supports batch size 1 for now
device = self._execution_device
# 3. Encode input image
if isinstance(input_images, np.ndarray):
input_images = torch.from_numpy(input_images.transpose(0, 3, 1, 2))
else:
assert isinstance(input_images, torch.Tensor)
input_images = input_images.to(device=device)
input_images = input_images * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
discard_clip_features = True
image_embeddings = self._encode_image(input_images, device,
discard=discard_clip_features,
chunk_size=decode_chunk_size
)
        # NOTE: Stable Video Diffusion was conditioned on fps - 1, which
        # is why it is reduced here.
# See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
fps = fps - 1
# 4. Encode input image using VAE
noise = randn_tensor(input_images.shape, generator=generator, device=device, dtype=input_images.dtype)
input_images = input_images + noise_aug_strength * noise
rgb_batch = input_images.unsqueeze(0)
added_time_ids = self._get_add_time_ids(
fps,
motion_bucket_id,
noise_aug_strength,
image_embeddings.dtype,
batch_size,
1, # do not modify this!
False, # do not modify this!
)
added_time_ids = added_time_ids.to(device)
if infer_mode == 'ours':
depth_pred_raw = self.single_infer_ours(
rgb_batch,
image_embeddings,
added_time_ids,
num_inference_steps,
show_progress_bar,
generator,
decode_chunk_size=decode_chunk_size,
latents=latents,
sigma_epsilon=sigma_epsilon,
)
elif infer_mode == 'replacement':
depth_pred_raw = self.single_infer_replacement(
rgb_batch,
image_embeddings,
added_time_ids,
num_inference_steps,
show_progress_bar,
generator,
decode_chunk_size=decode_chunk_size,
latents=latents,
)
elif infer_mode == 'naive':
depth_pred_raw = self.single_infer_naive_sliding_window(
rgb_batch,
image_embeddings,
added_time_ids,
num_inference_steps,
show_progress_bar,
generator,
decode_chunk_size=decode_chunk_size,
latents=latents,
            )
        else:
            raise ValueError(f"Unknown `infer_mode`: {infer_mode}")
depth_frames = depth_pred_raw.cpu().numpy().astype(np.float32)
self.maybe_free_model_hooks()
        return StableVideoDiffusionPipelineOutput(frames=depth_frames)
@torch.no_grad()
def single_infer_ours(self,
input_rgb: torch.Tensor,
image_embeddings: torch.Tensor,
added_time_ids: torch.Tensor,
num_inference_steps: int,
show_pbar: bool,
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
decode_chunk_size=1,
latents: Optional[torch.Tensor] = None,
sigma_epsilon: float = -4,
):
device = input_rgb.device
H, W = input_rgb.shape[-2:]
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.vae.to(dtype=torch.float32)
rgb_latent = self.encode_images(input_rgb)
rgb_latent = rgb_latent.to(image_embeddings.dtype)
torch.cuda.empty_cache()
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
batch_size, n_frames, _, _, _ = rgb_latent.shape
num_channels_latents = self.unet.config.in_channels
curr_frame = 0
depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
# first chunk
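        # Phase 1: denoise the first window (up to `self.n_tokens` frames) from
        # pure noise, with every frame sharing the same timestep.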
horizon = min(n_frames-curr_frame, self.n_tokens)
start_frame = 0
chunk = self.prepare_latents(
batch_size,
horizon,
num_channels_latents,
H,
W,
image_embeddings.dtype,
device,
generator,
latents,
)
depth_latent = torch.cat([depth_latent, chunk], 1)
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising first chunk",
)
else:
iterable = enumerate(timesteps)
for i, t in iterable:
curr_timesteps = torch.tensor([t]*horizon).to(device)
depth_latent = self.scheduler.scale_model_input(depth_latent, t)
noise_pred = self.unet(
torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
curr_timesteps[start_frame:],
image_embeddings[start_frame:curr_frame+horizon],
added_time_ids=added_time_ids
)[0]
depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:,-horizon:], t, depth_latent[:, curr_frame:]).prev_sample
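        # Reset the scheduler's internal step counter so the next denoising
        # loop starts cleanly.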
self.scheduler._step_index = None
curr_frame += horizon
pbar.update(horizon)
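        # Phase 2 ("ours"): extend the window with fresh noise and denoise only
        # the new frames; already-predicted context frames stay in the window
        # and are fed to the UNet at a fixed low noise level (`sigma_epsilon`).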
while curr_frame < n_frames:
if self.chunk_size > 0:
horizon = min(n_frames - curr_frame, self.chunk_size)
else:
horizon = min(n_frames - curr_frame, self.n_tokens)
assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
chunk = self.prepare_latents(
batch_size,
horizon,
num_channels_latents,
H,
W,
image_embeddings.dtype,
device,
generator,
latents,
)
depth_latent = torch.cat([depth_latent, chunk], 1)
start_frame = max(0, curr_frame + horizon - self.n_tokens)
pbar.set_postfix(
{
"start": start_frame,
"end": curr_frame + horizon,
}
)
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising ",
)
else:
iterable = enumerate(timesteps)
for i, t in iterable:
t_horizon = torch.tensor([t]*horizon).to(device)
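                # Condition the context frames at the fixed low noise level
                # `sigma_epsilon` rather than the scheduler's final timestep
                # (the commented alternative below).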
# t_context = timesteps[-1] * torch.ones((curr_frame,), dtype=t.dtype).to(device)
t_context = sigma_epsilon * torch.ones((curr_frame,), dtype=t.dtype).to(device)
                curr_timesteps = torch.cat((t_context, t_horizon), 0)
depth_latent[:, curr_frame:] = self.scheduler.scale_model_input(depth_latent[:, curr_frame:], t)
noise_pred = self.unet(
torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
curr_timesteps[start_frame:],
image_embeddings[start_frame:curr_frame+horizon],
added_time_ids=added_time_ids
)[0]
depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:,-horizon:], t, depth_latent[:, curr_frame:]).prev_sample
self.scheduler._step_index = None
curr_frame += horizon
pbar.update(horizon)
torch.cuda.empty_cache()
if needs_upcasting:
self.vae.to(dtype=torch.float16)
depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
# clip prediction
depth = torch.clip(depth, -1.0, 1.0)
# shift to [0, 1]
depth = (depth + 1.0) / 2.0
return depth.squeeze(0)
@torch.no_grad()
def single_infer_replacement(self,
input_rgb: torch.Tensor,
image_embeddings: torch.Tensor,
added_time_ids: torch.Tensor,
num_inference_steps: int,
show_pbar: bool,
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
decode_chunk_size=1,
latents: Optional[torch.Tensor] = None,
):
device = input_rgb.device
H, W = input_rgb.shape[-2:]
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.vae.to(dtype=torch.float32)
rgb_latent = self.encode_images(input_rgb)
rgb_latent = rgb_latent.to(image_embeddings.dtype)
torch.cuda.empty_cache()
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
batch_size, n_frames, _, _, _ = rgb_latent.shape
num_channels_latents = self.unet.config.in_channels
curr_frame = 0
depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
# first chunk
horizon = min(n_frames-curr_frame, self.n_tokens)
start_frame = 0
chunk = self.prepare_latents(
batch_size,
horizon,
num_channels_latents,
H,
W,
image_embeddings.dtype,
device,
generator,
latents,
)
depth_latent = torch.cat([depth_latent, chunk], 1)
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising first chunk",
)
else:
iterable = enumerate(timesteps)
for i, t in iterable:
curr_timesteps = torch.tensor([t]*horizon).to(device)
depth_latent = self.scheduler.scale_model_input(depth_latent, t)
noise_pred = self.unet(
torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
curr_timesteps[start_frame:],
image_embeddings[start_frame:curr_frame+horizon],
added_time_ids=added_time_ids
)[0]
depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:,-horizon:], t, depth_latent[:, curr_frame:]).prev_sample
self.scheduler._step_index = None
curr_frame += horizon
pbar.update(horizon)
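        # Phase 2 ("replacement"): extend the window and denoise the new chunk
        # while repeatedly re-injecting noised copies of the previous
        # predictions as context.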
while curr_frame < n_frames:
if self.chunk_size > 0:
horizon = min(n_frames - curr_frame, self.chunk_size)
else:
horizon = min(n_frames - curr_frame, self.n_tokens)
assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
chunk = self.prepare_latents(
batch_size,
horizon,
num_channels_latents,
H,
W,
image_embeddings.dtype,
device,
generator,
latents,
)
depth_latent = torch.cat([depth_latent, chunk], 1)
start_frame = max(0, curr_frame + horizon - self.n_tokens)
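            # Keep a clean copy of the context latents; each step below re-noises
            # them to the current sigma and restores them after the scheduler step.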
depth_pred_last_latent = depth_latent[:, start_frame:curr_frame].clone()
pbar.set_postfix(
{
"start": start_frame,
"end": curr_frame + horizon,
}
)
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising ",
)
else:
iterable = enumerate(timesteps)
for i, t in iterable:
curr_timesteps = torch.tensor([t]*(curr_frame+horizon-start_frame)).to(device)
epsilon = randn_tensor(
depth_pred_last_latent.shape,
generator=generator,
device=device,
dtype=image_embeddings.dtype
)
depth_latent[:, start_frame:curr_frame] = depth_pred_last_latent + epsilon * self.scheduler.sigmas[i]
depth_latent[:, start_frame:] = self.scheduler.scale_model_input(depth_latent[:, start_frame:], t)
noise_pred = self.unet(
torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
curr_timesteps,
image_embeddings[start_frame:curr_frame+horizon],
added_time_ids=added_time_ids
)[0]
depth_latent[:, start_frame:] = self.scheduler.step(noise_pred, t, depth_latent[:, start_frame:]).prev_sample
depth_latent[:, start_frame:curr_frame] = depth_pred_last_latent
self.scheduler._step_index = None
curr_frame += horizon
pbar.update(horizon)
torch.cuda.empty_cache()
if needs_upcasting:
self.vae.to(dtype=torch.float16)
depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
# clip prediction
depth = torch.clip(depth, -1.0, 1.0)
# shift to [0, 1]
depth = (depth + 1.0) / 2.0
return depth.squeeze(0)
@torch.no_grad()
def single_infer_naive_sliding_window(self,
input_rgb: torch.Tensor,
image_embeddings: torch.Tensor,
added_time_ids: torch.Tensor,
num_inference_steps: int,
show_pbar: bool,
generator: Optional[Union[torch.Generator, List[torch.Generator]]],
decode_chunk_size=1,
latents: Optional[torch.Tensor] = None,
):
device = input_rgb.device
H, W = input_rgb.shape[-2:]
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.vae.to(dtype=torch.float32)
rgb_latent = self.encode_images(input_rgb)
rgb_latent = rgb_latent.to(image_embeddings.dtype)
torch.cuda.empty_cache()
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
batch_size, n_frames, _, _, _ = rgb_latent.shape
num_channels_latents = self.unet.config.in_channels
curr_frame = 0
depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
# first chunk
horizon = min(n_frames-curr_frame, self.n_tokens)
start_frame = 0
chunk = self.prepare_latents(
batch_size,
horizon,
num_channels_latents,
H,
W,
image_embeddings.dtype,
device,
generator,
latents,
)
depth_latent = torch.cat([depth_latent, chunk], 1)
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising first chunk",
)
else:
iterable = enumerate(timesteps)
for i, t in iterable:
curr_timesteps = torch.tensor([t]*horizon).to(device)
depth_latent = self.scheduler.scale_model_input(depth_latent, t)
noise_pred = self.unet(
torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], depth_latent[:, start_frame:]], dim=2),
curr_timesteps[start_frame:],
image_embeddings[start_frame:curr_frame+horizon],
added_time_ids=added_time_ids
)[0]
depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:,-horizon:], t, depth_latent[:, curr_frame:]).prev_sample
self.scheduler._step_index = None
curr_frame += horizon
pbar.update(horizon)
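        # Phase 2 ("naive" sliding window): each new window is denoised from
        # scratch; the overlapping context frames are re-predicted and then
        # discarded, keeping only the final `horizon` frames.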
while curr_frame < n_frames:
if self.chunk_size > 0:
horizon = min(n_frames - curr_frame, self.chunk_size)
else:
horizon = min(n_frames - curr_frame, self.n_tokens)
assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
start_frame = max(0, curr_frame + horizon - self.n_tokens)
chunk = self.prepare_latents(
batch_size,
curr_frame+horizon-start_frame,
num_channels_latents,
H,
W,
image_embeddings.dtype,
device,
generator,
latents,
)
pbar.set_postfix(
{
"start": start_frame,
"end": curr_frame + horizon,
}
)
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising ",
)
else:
iterable = enumerate(timesteps)
for i, t in iterable:
curr_timesteps = torch.tensor([t]*(curr_frame+horizon-start_frame)).to(device)
chunk = self.scheduler.scale_model_input(chunk, t)
noise_pred = self.unet(
torch.cat([rgb_latent[:, start_frame:curr_frame+horizon], chunk], dim=2),
curr_timesteps,
image_embeddings[start_frame:curr_frame+horizon],
added_time_ids=added_time_ids
)[0]
chunk = self.scheduler.step(noise_pred, t, chunk).prev_sample
depth_latent = torch.cat([depth_latent, chunk[:, -horizon:]], 1)
self.scheduler._step_index = None
curr_frame += horizon
pbar.update(horizon)
torch.cuda.empty_cache()
if needs_upcasting:
self.vae.to(dtype=torch.float16)
depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
# clip prediction
depth = torch.clip(depth, -1.0, 1.0)
# shift to [0, 1]
depth = (depth + 1.0) / 2.0
return depth.squeeze(0)
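# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Assumptions are flagged inline:
# the checkpoint id is hypothetical, and `n_tokens` / `chunk_size` are
# sliding-window attributes this pipeline reads but does not define here.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = ChronoDepthPipeline.from_pretrained(
        "jhaoshao/ChronoDepth",  # hypothetical checkpoint id
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.n_tokens = 10   # window length in frames (assumed value)
    pipe.chunk_size = 5  # new frames denoised per autoregressive step (assumed value)

    # Dummy clip: 20 RGB frames, 576x768, values in [0, 1].
    video = np.random.rand(20, 576, 768, 3).astype(np.float32)
    out = pipe(video, num_inference_steps=10, infer_mode="ours")
    print(out.frames.shape)  # (T, 1, H, W) float32 depth in [0, 1]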