from typing import Tuple

import torch
from torch import nn
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor
from torch.nn import functional

from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder


class Downsample3D(nn.Module):
    """Wraps a stride-2 convolution built by ``make_conv_nd``.

    NOTE(review): ``make_conv_nd`` is neither defined nor imported in this file,
    so constructing this class raises ``NameError`` unless the name is supplied
    elsewhere. Within this module the class is only used as an ``isinstance``
    target in ``get_vae_size_scale_factor``.
    """

    def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
        super().__init__()
        stride: int = 2
        self.padding = padding
        self.in_channels = in_channels
        self.dims = dims
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, downsample_in_time=True):
        """Apply the strided convolution to ``x``.

        When the module was built with ``padding == 0``, ``x`` is first
        zero-padded by one element on the trailing side of the last two axes
        (and of the third-to-last axis too when ``dims != 2`` and
        ``downsample_in_time`` is True). When ``dims == (2, 1)`` and
        ``downsample_in_time`` is False, the conv is called with
        ``skip_time_conv=True``.
        """
        conv = self.conv
        if self.padding == 0:
            if self.dims == 2:
                padding = (0, 1, 0, 1)
            else:
                padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)

            x = functional.pad(x, padding, mode="constant", value=0)

            if self.dims == (2, 1) and not downsample_in_time:
                return conv(x, skip_time_conv=True)

        return conv(x)


def vae_encode(
    media_items: Tensor,
    vae: AutoencoderKL,
    split_size: int = 1,
    vae_per_channel_normalize: bool = False,
) -> Tensor:
    """Encode images or videos into normalized VAE latents.

    Args:
        media_items: Input batch, ``(batch, 3, height, width)`` for images or
            ``(batch, 3, frames, height, width)`` for videos. Only 3-channel
            input is accepted.
        vae: The autoencoder. A ``CausalVideoAutoencoder`` receives video
            tensors unchanged. For any other VAE, the frame axis of a video is
            folded into the batch axis, so frames are encoded as independent
            images.
        split_size: Number of equal sub-batches to encode one after another.
            The (possibly frame-folded) batch length must be divisible by it.
            ``1`` encodes everything in a single ``vae.encode`` call.
        vae_per_channel_normalize: If True, normalize with the VAE's per-channel
            ``mean_of_means`` / ``std_of_means``. Otherwise multiply by
            ``vae.config.scaling_factor``.

    Returns:
        Latents sampled from ``vae.encode(...).latent_dist`` and normalized by
        ``normalize_latents``. Frames that were folded into the batch are
        rearranged back to ``(batch, C, frames, h, w)``.

    Raises:
        ValueError: If the input does not have 3 channels, or the batch length
            is not divisible by ``split_size``.

    Example:
        >>> vae = AutoencoderKL.from_pretrained("your-model-name")
        >>> videos = torch.rand(10, 3, 8, 256, 256)  # 10 videos of 8 frames
        >>> latents = vae_encode(videos, vae)
        >>> latents.shape  # depends on the model's latent configuration
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    # Non-causal VAEs here consume 4-D input: fold frames into the batch axis.
    fold_frames = is_video_shaped and not isinstance(vae, CausalVideoAutoencoder)
    if fold_frames:
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")

    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split")
        encode_bs = len(media_items) // split_size
        latents = torch.cat(
            [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)],
            dim=0,
        )
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    if fold_frames:
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents


def vae_decode(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool = True,
    split_size: int = 1,
    vae_per_channel_normalize: bool = False,
) -> Tensor:
    """Decode normalized latents back to pixel space (counterpart of ``vae_encode``).

    Args:
        latents: 4-D image latents or 5-D ``(batch, C, frames, h, w)`` video latents.
        vae: The autoencoder. For a non-causal VAE, video latents have their
            frame axis folded into the batch before decoding and unfolded after.
        is_video: Forwarded to ``_run_decoder``. For a ``CausalVideoAutoencoder``,
            False requests a single output frame.
        split_size: Number of equal sub-batches to decode one after another.
            The (possibly frame-folded) batch length must be divisible by it.
        vae_per_channel_normalize: Must match the value used at encode time.

    Returns:
        The decoded tensor, restored to ``b c n h w`` layout when frames were folded.

    Raises:
        ValueError: If the batch length is not divisible by ``split_size``.
    """
    is_video_shaped = latents.dim() == 5
    batch_size = latents.shape[0]

    fold_frames = is_video_shaped and not isinstance(vae, CausalVideoAutoencoder)
    if fold_frames:
        latents = rearrange(latents, "b c n h w -> (b n) c h w")

    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split")
        decode_bs = len(latents) // split_size
        images = torch.cat(
            [
                _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
                for latent_batch in latents.split(decode_bs)
            ],
            dim=0,
        )
    else:
        images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)

    if fold_frames:
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images


def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
    """Un-normalize ``latents``, run ``vae.decode`` and return element 0 of its output."""
    if isinstance(vae, CausalVideoAutoencoder):
        *_, fl, hl, wl = latents.shape
        temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
        latents = latents.to(vae.dtype)
        # The causal decoder is given the expected pixel-space output size;
        # for still images (is_video=False) a single frame is requested.
        # NOTE(review): batch (1) and channels (3) are hard-coded in target_shape;
        # confirm the decoder does not rely on the batch entry when len(latents) > 1.
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
            target_shape=(
                1,
                3,
                fl * temporal_scale if is_video else 1,
                hl * spatial_scale,
                wl * spatial_scale,
            ),
        )[0]
    else:
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
        )[0]
    return image


def get_vae_size_scale_factor(vae: AutoencoderKL) -> Tuple[int, int, int]:
    """Return the ``(temporal, spatial, spatial)`` downscale factors of ``vae``.

    A ``CausalVideoAutoencoder`` reports its factors directly. For other VAEs the
    spatial factor is ``config.patch_size * 2**k``, where ``k`` is the number of
    encoder down blocks whose ``downsample`` is a ``Downsample3D``.
    """
    if isinstance(vae, CausalVideoAutoencoder):
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        down_blocks = sum(1 for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D))
        spatial = vae.config.patch_size * 2**down_blocks
        # A config exposing ``patch_size_t`` is treated as a video VAE with
        # temporal compression; anything else has none.
        # NOTE(review): assumes only video VAEs define ``patch_size_t`` — confirm.
        patch_size_t = getattr(vae.config, "patch_size_t", None)
        temporal = patch_size_t * 2**down_blocks if patch_size_t is not None else 1

    return (temporal, spatial, spatial)


def _per_channel_stat(stat: Tensor, like: Tensor) -> Tensor:
    """Reshape a per-channel statistic so it broadcasts against ``like`` along dim 1.

    Works for latents of any rank >= 2 (e.g. 4-D frame-folded or 5-D video latents).
    """
    return stat.to(device=like.device, dtype=like.dtype).view(1, -1, *([1] * (like.dim() - 2)))


def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
    """Map raw VAE latents into the normalized latent space.

    With ``vae_per_channel_normalize`` the result is
    ``(latents - vae.mean_of_means) / vae.std_of_means`` applied per channel
    (dim 1). Otherwise it is ``latents * vae.config.scaling_factor``.
    """
    if vae_per_channel_normalize:
        mean = _per_channel_stat(vae.mean_of_means, latents)
        std = _per_channel_stat(vae.std_of_means, latents)
        return (latents - mean) / std
    return latents * vae.config.scaling_factor


def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
    """Invert ``normalize_latents``.

    With ``vae_per_channel_normalize`` the result is
    ``latents * vae.std_of_means + vae.mean_of_means`` applied per channel
    (dim 1). Otherwise it is ``latents / vae.config.scaling_factor``.
    """
    if vae_per_channel_normalize:
        mean = _per_channel_stat(vae.mean_of_means, latents)
        std = _per_channel_stat(vae.std_of_means, latents)
        return latents * std + mean
    return latents / vae.config.scaling_factor