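# ChronoDepth-style video depth estimation pipeline built on Stable Video Diffusion.
# It encodes RGB frames with the SVD VAE, denoises depth latents with the SVD UNet using one of
# three sliding-window strategies ('ours', 'replacement', 'naive'), and decodes them back into
# per-frame depth maps in [0, 1].
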
import inspect
from typing import Union, Optional, List

import torch
import numpy as np
from tqdm.auto import tqdm
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    _resize_with_antialiasing,
    StableVideoDiffusionPipelineOutput,
    StableVideoDiffusionPipeline,
)
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
from einops import rearrange

class ChronoDepthPipeline(StableVideoDiffusionPipeline):

    def encode_images(self,
                      images: torch.Tensor,
                      decode_chunk_size=5,
                      ):
        video_length = images.shape[1]
        images = rearrange(images, "b f c h w -> (b f) c h w")
        latents = []
        for i in range(0, images.shape[0], decode_chunk_size):
            latents_chunk = self.vae.encode(images[i : i + decode_chunk_size]).latent_dist.sample()
            latents.append(latents_chunk)
        latents = torch.cat(latents, dim=0)
        latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
        latents = latents * self.vae.config.scaling_factor
        return latents

    def _encode_image(self, images, device, discard=True, chunk_size=14):
        '''
        Encode images with the CLIP image encoder. If `discard` is True, the images are
        replaced with zero tensors so that the resulting image embeddings carry no content.
        '''
        dtype = next(self.image_encoder.parameters()).dtype

        images = _resize_with_antialiasing(images, (224, 224))
        images = (images + 1.0) / 2.0
        if discard:
            images = torch.zeros_like(images)

        image_embeddings = []
        for i in range(0, images.shape[0], chunk_size):
            tmp = self.feature_extractor(
                images=images[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values
            tmp = tmp.to(device=device, dtype=dtype)
            image_embeddings.append(self.image_encoder(tmp).image_embeds)
        image_embeddings = torch.cat(image_embeddings, dim=0)
        image_embeddings = image_embeddings.unsqueeze(1)  # [t, 1, 1024]

        return image_embeddings

    def decode_depth(self, depth_latent: torch.Tensor, decode_chunk_size=5) -> torch.Tensor:
        num_frames = depth_latent.shape[1]
        depth_latent = rearrange(depth_latent, "b f c h w -> (b f) c h w")

        depth_latent = depth_latent / self.vae.config.scaling_factor

        forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
        accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())

        depth_frames = []
        for i in range(0, depth_latent.shape[0], decode_chunk_size):
            num_frames_in = depth_latent[i : i + decode_chunk_size].shape[0]
            decode_kwargs = {}
            if accepts_num_frames:
                # we only pass num_frames_in if it's expected
                decode_kwargs["num_frames"] = num_frames_in

            depth_frame = self.vae.decode(depth_latent[i : i + decode_chunk_size], **decode_kwargs).sample
            depth_frames.append(depth_frame)
        depth_frames = torch.cat(depth_frames, dim=0)
        depth_frames = depth_frames.reshape(-1, num_frames, *depth_frames.shape[1:])
        depth_mean = depth_frames.mean(dim=2, keepdim=True)

        return depth_mean

    def check_inputs(self, images, height, width):
        if (
            not isinstance(images, torch.Tensor)
            and not isinstance(images, np.ndarray)
        ):
            raise ValueError(
                "`images` has to be of type `torch.Tensor` or `numpy.ndarray` but is"
                f" {type(images)}"
            )

        if height % 64 != 0 or width % 64 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")

    def __call__(
        self,
        input_images: Union[np.ndarray, torch.FloatTensor],
        height: int = 576,
        width: int = 768,
        num_inference_steps: int = 10,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        show_progress_bar: bool = True,
        latents: Optional[torch.Tensor] = None,
        infer_mode: str = 'ours',
        sigma_epsilon: float = -4,
    ):
        """
        Args:
            input_images: shape [T, H, W, 3] if np.ndarray or [T, 3, H, W] if torch.FloatTensor, range [0, 1]
            height: int, height of the input images
            width: int, width of the input images
            num_inference_steps: int, number of denoising steps
            fps: int, frames-per-second conditioning value
            motion_bucket_id: int, motion bucket id conditioning value
            noise_aug_strength: float, noise augmentation strength
            decode_chunk_size: int, number of frames encoded/decoded per chunk
            generator: torch.Generator or List[torch.Generator], random number generator
            show_progress_bar: bool, whether to show a progress bar
            latents: optional pre-generated latents passed to `prepare_latents`
            infer_mode: 'ours', 'replacement', or 'naive' sliding-window inference
            sigma_epsilon: float, timestep value assigned to already-denoised context frames in 'ours' mode
        Returns:
            StableVideoDiffusionPipelineOutput whose `frames` is a float32 np.ndarray of shape
            [T, 1, H, W] containing depth predictions in [0, 1].
        """
        assert height >= 0 and width >= 0
        assert num_inference_steps >= 1

        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(input_images, height, width)

        # 2. Define call parameters
        batch_size = 1  # only support batch size 1 for now
        device = self._execution_device

        # 3. Encode input image
        if isinstance(input_images, np.ndarray):
            input_images = torch.from_numpy(input_images.transpose(0, 3, 1, 2))
        else:
            assert isinstance(input_images, torch.Tensor)
        input_images = input_images.to(device=device)
        input_images = input_images * 2.0 - 1.0  # [0, 1] -> [-1, 1], in [t, c, h, w]

        discard_clip_features = True
        image_embeddings = self._encode_image(input_images, device,
                                              discard=discard_clip_features,
                                              chunk_size=decode_chunk_size
                                              )

        # NOTE: Stable Video Diffusion was conditioned on fps - 1, which
        # is why it is reduced here.
        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        fps = fps - 1

        # 4. Encode input image using VAE
        noise = randn_tensor(input_images.shape, generator=generator, device=device, dtype=input_images.dtype)
        input_images = input_images + noise_aug_strength * noise

        rgb_batch = input_images.unsqueeze(0)

        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            1,  # do not modify this!
            False,  # do not modify this!
        )
        added_time_ids = added_time_ids.to(device)

        if infer_mode == 'ours':
            depth_pred_raw = self.single_infer_ours(
                rgb_batch,
                image_embeddings,
                added_time_ids,
                num_inference_steps,
                show_progress_bar,
                generator,
                decode_chunk_size=decode_chunk_size,
                latents=latents,
                sigma_epsilon=sigma_epsilon,
            )
        elif infer_mode == 'replacement':
            depth_pred_raw = self.single_infer_replacement(
                rgb_batch,
                image_embeddings,
                added_time_ids,
                num_inference_steps,
                show_progress_bar,
                generator,
                decode_chunk_size=decode_chunk_size,
                latents=latents,
            )
        elif infer_mode == 'naive':
            depth_pred_raw = self.single_infer_naive_sliding_window(
                rgb_batch,
                image_embeddings,
                added_time_ids,
                num_inference_steps,
                show_progress_bar,
                generator,
                decode_chunk_size=decode_chunk_size,
                latents=latents,
            )
        else:
            raise ValueError(f"Unknown infer_mode: {infer_mode}")

        depth_frames = depth_pred_raw.cpu().numpy().astype(np.float32)

        self.maybe_free_model_hooks()

        return StableVideoDiffusionPipelineOutput(
            frames=depth_frames,
        )

    def single_infer_ours(self,
                          input_rgb: torch.Tensor,
                          image_embeddings: torch.Tensor,
                          added_time_ids: torch.Tensor,
                          num_inference_steps: int,
                          show_pbar: bool,
                          generator: Optional[Union[torch.Generator, List[torch.Generator]]],
                          decode_chunk_size=1,
                          latents: Optional[torch.Tensor] = None,
                          sigma_epsilon: float = -4,
                          ):
        device = input_rgb.device
        H, W = input_rgb.shape[-2:]

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)
        rgb_latent = self.encode_images(input_rgb)
        rgb_latent = rgb_latent.to(image_embeddings.dtype)
        torch.cuda.empty_cache()
        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        batch_size, n_frames, _, _, _ = rgb_latent.shape
        num_channels_latents = self.unet.config.in_channels
        curr_frame = 0
        depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
        pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")

        # first chunk: denoise the first `horizon` frames from pure noise
        horizon = min(n_frames - curr_frame, self.n_tokens)
        start_frame = 0
        chunk = self.prepare_latents(
            batch_size,
            horizon,
            num_channels_latents,
            H,
            W,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )
        depth_latent = torch.cat([depth_latent, chunk], 1)

        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising first chunk",
            )
        else:
            iterable = enumerate(timesteps)
        for i, t in iterable:
            curr_timesteps = torch.tensor([t] * horizon).to(device)
            depth_latent = self.scheduler.scale_model_input(depth_latent, t)
            noise_pred = self.unet(
                torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                curr_timesteps[start_frame:],
                image_embeddings[start_frame:curr_frame + horizon],
                added_time_ids=added_time_ids
            )[0]
            depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:, -horizon:], t, depth_latent[:, curr_frame:]).prev_sample
            self.scheduler._step_index = None
        curr_frame += horizon
        pbar.update(horizon)

        # subsequent chunks: new frames are denoised while already-denoised frames inside the
        # window act as context and are given the small timestep value `sigma_epsilon`,
        # i.e. treated as (nearly) clean
        while curr_frame < n_frames:
            if self.chunk_size > 0:
                horizon = min(n_frames - curr_frame, self.chunk_size)
            else:
                horizon = min(n_frames - curr_frame, self.n_tokens)
            assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
            chunk = self.prepare_latents(
                batch_size,
                horizon,
                num_channels_latents,
                H,
                W,
                image_embeddings.dtype,
                device,
                generator,
                latents,
            )
            depth_latent = torch.cat([depth_latent, chunk], 1)
            start_frame = max(0, curr_frame + horizon - self.n_tokens)

            pbar.set_postfix(
                {
                    "start": start_frame,
                    "end": curr_frame + horizon,
                }
            )

            if show_pbar:
                iterable = tqdm(
                    enumerate(timesteps),
                    total=len(timesteps),
                    leave=False,
                    desc=" " * 4 + "Diffusion denoising ",
                )
            else:
                iterable = enumerate(timesteps)
            for i, t in iterable:
                t_horizon = torch.tensor([t] * horizon).to(device)
                # t_context = timesteps[-1] * torch.ones((curr_frame,), dtype=t.dtype).to(device)
                t_context = sigma_epsilon * torch.ones((curr_frame,), dtype=t.dtype).to(device)
                curr_timesteps = torch.concatenate((t_context, t_horizon), 0)
                depth_latent[:, curr_frame:] = self.scheduler.scale_model_input(depth_latent[:, curr_frame:], t)
                noise_pred = self.unet(
                    torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                    curr_timesteps[start_frame:],
                    image_embeddings[start_frame:curr_frame + horizon],
                    added_time_ids=added_time_ids
                )[0]
                # only the new `horizon` frames are updated by the scheduler step
                depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:, -horizon:], t, depth_latent[:, curr_frame:]).prev_sample
                self.scheduler._step_index = None
            curr_frame += horizon
            pbar.update(horizon)

        torch.cuda.empty_cache()
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
        depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth.squeeze(0)

    def single_infer_replacement(self,
                                 input_rgb: torch.Tensor,
                                 image_embeddings: torch.Tensor,
                                 added_time_ids: torch.Tensor,
                                 num_inference_steps: int,
                                 show_pbar: bool,
                                 generator: Optional[Union[torch.Generator, List[torch.Generator]]],
                                 decode_chunk_size=1,
                                 latents: Optional[torch.Tensor] = None,
                                 ):
        device = input_rgb.device
        H, W = input_rgb.shape[-2:]

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)
        rgb_latent = self.encode_images(input_rgb)
        rgb_latent = rgb_latent.to(image_embeddings.dtype)
        torch.cuda.empty_cache()
        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        batch_size, n_frames, _, _, _ = rgb_latent.shape
        num_channels_latents = self.unet.config.in_channels
        curr_frame = 0
        depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
        pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")

        # first chunk: denoise the first `horizon` frames from pure noise
        horizon = min(n_frames - curr_frame, self.n_tokens)
        start_frame = 0
        chunk = self.prepare_latents(
            batch_size,
            horizon,
            num_channels_latents,
            H,
            W,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )
        depth_latent = torch.cat([depth_latent, chunk], 1)

        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising first chunk",
            )
        else:
            iterable = enumerate(timesteps)
        for i, t in iterable:
            curr_timesteps = torch.tensor([t] * horizon).to(device)
            depth_latent = self.scheduler.scale_model_input(depth_latent, t)
            noise_pred = self.unet(
                torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                curr_timesteps[start_frame:],
                image_embeddings[start_frame:curr_frame + horizon],
                added_time_ids=added_time_ids
            )[0]
            depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:, -horizon:], t, depth_latent[:, curr_frame:]).prev_sample
            self.scheduler._step_index = None
        curr_frame += horizon
        pbar.update(horizon)

        # subsequent chunks: context frames are re-noised from their already-denoised latents to
        # the current noise level at every step (replacement trick) and restored after each step
        while curr_frame < n_frames:
            if self.chunk_size > 0:
                horizon = min(n_frames - curr_frame, self.chunk_size)
            else:
                horizon = min(n_frames - curr_frame, self.n_tokens)
            assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
            chunk = self.prepare_latents(
                batch_size,
                horizon,
                num_channels_latents,
                H,
                W,
                image_embeddings.dtype,
                device,
                generator,
                latents,
            )
            depth_latent = torch.cat([depth_latent, chunk], 1)
            start_frame = max(0, curr_frame + horizon - self.n_tokens)
            depth_pred_last_latent = depth_latent[:, start_frame:curr_frame].clone()

            pbar.set_postfix(
                {
                    "start": start_frame,
                    "end": curr_frame + horizon,
                }
            )

            if show_pbar:
                iterable = tqdm(
                    enumerate(timesteps),
                    total=len(timesteps),
                    leave=False,
                    desc=" " * 4 + "Diffusion denoising ",
                )
            else:
                iterable = enumerate(timesteps)
            for i, t in iterable:
                curr_timesteps = torch.tensor([t] * (curr_frame + horizon - start_frame)).to(device)
                epsilon = randn_tensor(
                    depth_pred_last_latent.shape,
                    generator=generator,
                    device=device,
                    dtype=image_embeddings.dtype
                )
                depth_latent[:, start_frame:curr_frame] = depth_pred_last_latent + epsilon * self.scheduler.sigmas[i]
                depth_latent[:, start_frame:] = self.scheduler.scale_model_input(depth_latent[:, start_frame:], t)
                noise_pred = self.unet(
                    torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                    curr_timesteps,
                    image_embeddings[start_frame:curr_frame + horizon],
                    added_time_ids=added_time_ids
                )[0]
                depth_latent[:, start_frame:] = self.scheduler.step(noise_pred, t, depth_latent[:, start_frame:]).prev_sample
                depth_latent[:, start_frame:curr_frame] = depth_pred_last_latent
                self.scheduler._step_index = None
            curr_frame += horizon
            pbar.update(horizon)

        torch.cuda.empty_cache()
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
        depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth.squeeze(0)

    def single_infer_naive_sliding_window(self,
                                          input_rgb: torch.Tensor,
                                          image_embeddings: torch.Tensor,
                                          added_time_ids: torch.Tensor,
                                          num_inference_steps: int,
                                          show_pbar: bool,
                                          generator: Optional[Union[torch.Generator, List[torch.Generator]]],
                                          decode_chunk_size=1,
                                          latents: Optional[torch.Tensor] = None,
                                          ):
        device = input_rgb.device
        H, W = input_rgb.shape[-2:]

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)
        rgb_latent = self.encode_images(input_rgb)
        rgb_latent = rgb_latent.to(image_embeddings.dtype)
        torch.cuda.empty_cache()
        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        batch_size, n_frames, _, _, _ = rgb_latent.shape
        num_channels_latents = self.unet.config.in_channels
        curr_frame = 0
        depth_latent = torch.tensor([], dtype=image_embeddings.dtype, device=device)
        pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")

        # first chunk: denoise the first `horizon` frames from pure noise
        horizon = min(n_frames - curr_frame, self.n_tokens)
        start_frame = 0
        chunk = self.prepare_latents(
            batch_size,
            horizon,
            num_channels_latents,
            H,
            W,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )
        depth_latent = torch.cat([depth_latent, chunk], 1)

        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising first chunk",
            )
        else:
            iterable = enumerate(timesteps)
        for i, t in iterable:
            curr_timesteps = torch.tensor([t] * horizon).to(device)
            depth_latent = self.scheduler.scale_model_input(depth_latent, t)
            noise_pred = self.unet(
                torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], depth_latent[:, start_frame:]], dim=2),
                curr_timesteps[start_frame:],
                image_embeddings[start_frame:curr_frame + horizon],
                added_time_ids=added_time_ids
            )[0]
            depth_latent[:, curr_frame:] = self.scheduler.step(noise_pred[:, -horizon:], t, depth_latent[:, curr_frame:]).prev_sample
            self.scheduler._step_index = None
        curr_frame += horizon
        pbar.update(horizon)

        # subsequent windows: each window is denoised independently from fresh noise and only
        # the last `horizon` (non-overlapping) frames are kept
        while curr_frame < n_frames:
            if self.chunk_size > 0:
                horizon = min(n_frames - curr_frame, self.chunk_size)
            else:
                horizon = min(n_frames - curr_frame, self.n_tokens)
            assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
            start_frame = max(0, curr_frame + horizon - self.n_tokens)
            chunk = self.prepare_latents(
                batch_size,
                curr_frame + horizon - start_frame,
                num_channels_latents,
                H,
                W,
                image_embeddings.dtype,
                device,
                generator,
                latents,
            )

            pbar.set_postfix(
                {
                    "start": start_frame,
                    "end": curr_frame + horizon,
                }
            )

            if show_pbar:
                iterable = tqdm(
                    enumerate(timesteps),
                    total=len(timesteps),
                    leave=False,
                    desc=" " * 4 + "Diffusion denoising ",
                )
            else:
                iterable = enumerate(timesteps)
            for i, t in iterable:
                curr_timesteps = torch.tensor([t] * (curr_frame + horizon - start_frame)).to(device)
                chunk = self.scheduler.scale_model_input(chunk, t)
                noise_pred = self.unet(
                    torch.cat([rgb_latent[:, start_frame:curr_frame + horizon], chunk], dim=2),
                    curr_timesteps,
                    image_embeddings[start_frame:curr_frame + horizon],
                    added_time_ids=added_time_ids
                )[0]
                chunk = self.scheduler.step(noise_pred, t, chunk).prev_sample
            depth_latent = torch.cat([depth_latent, chunk[:, -horizon:]], 1)
            self.scheduler._step_index = None
            curr_frame += horizon
            pbar.update(horizon)

        torch.cuda.empty_cache()
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
        depth = self.decode_depth(depth_latent, decode_chunk_size=decode_chunk_size)
        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth.squeeze(0)
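

# Usage sketch (not part of the original file): a minimal, hedged example of how this pipeline
# could be driven end to end. The checkpoint id "some-org/chronodepth-checkpoint" and the input
# frames are placeholders/assumptions; substitute a real ChronoDepth checkpoint and your own
# video frames. `n_tokens` and `chunk_size` are attributes the sampling loops above expect to
# find on the pipeline, so they are set explicitly here.
if __name__ == "__main__":
    # Load weights into the custom pipeline class (placeholder checkpoint id).
    pipe = ChronoDepthPipeline.from_pretrained(
        "some-org/chronodepth-checkpoint",  # assumption: replace with the actual checkpoint
        torch_dtype=torch.float16,
    ).to("cuda")
    pipe.n_tokens = 10   # sliding-window length used by the samplers above
    pipe.chunk_size = 5  # frames newly denoised per window (<= 0 falls back to n_tokens)

    # Dummy clip of 20 RGB frames in [0, 1], shape [T, H, W, 3]; height/width must be multiples of 64.
    frames = np.random.rand(20, 576, 768, 3).astype(np.float32)

    with torch.no_grad():
        out = pipe(frames, num_inference_steps=10, infer_mode="ours")

    depth = out.frames  # float32 np.ndarray, shape [T, 1, H, W], values in [0, 1]
    print(depth.shape, depth.min(), depth.max())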