Spaces:
Running
Running
import math | |
from typing import Optional | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torch.nn.init as init | |
from einops import rearrange | |
from torch import nn | |
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """Return a fixed 2D sine/cosine positional-embedding table for a patch grid.

    Args:
        embed_dim: Channels per position. Half encode one grid axis and half the
            other, and each half is split again into sin/cos, so in practice it
            must be a multiple of 4.
        grid_size: ``int`` for a square grid, or an indexable ``(rows, cols)`` pair.
        cls_token: Together with ``extra_tokens > 0``, prepend all-zero rows.
        extra_tokens: How many zero rows to prepend when ``cls_token`` is set.
        interpolation_scale: Extra divisor applied to every coordinate.
        base_size: Coordinates along an axis of length ``n`` are
            ``arange(n) / (n / base_size) / interpolation_scale``, so every grid
            spans ``[0, base_size) / interpolation_scale`` regardless of ``n``.

    Returns:
        ``np.ndarray`` of shape ``(rows * cols, embed_dim)``, or
        ``(extra_tokens + rows * cols, embed_dim)`` when zero rows are prepended.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    rows, cols = grid_size[0], grid_size[1]

    # Rescale integer indices so any grid length maps onto the base_size range.
    row_coords = np.arange(rows, dtype=np.float32) / (rows / base_size) / interpolation_scale
    col_coords = np.arange(cols, dtype=np.float32) / (cols / base_size) / interpolation_scale

    # meshgrid(cols, rows) puts the column coordinate in plane 0 and the row
    # coordinate in plane 1, each of shape (rows, cols). The 4-D reshape is only
    # nominal: the 1-D encoder flattens each plane in row-major order anyway.
    planes = np.stack(np.meshgrid(col_coords, row_coords), axis=0)
    planes = planes.reshape([2, 1, cols, rows])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, planes)

    if cls_token and extra_tokens > 0:
        leading_zeros = np.zeros([extra_tokens, embed_dim])
        table = np.concatenate([leading_zeros, table], axis=0)
    return table
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate the 1-D sin/cos encodings of two coordinate planes.

    Args:
        embed_dim: Total channels per position; must be even. Each half is passed
            to ``get_1d_sincos_pos_embed_from_grid``, which also requires an even
            size, so in practice ``embed_dim`` must be a multiple of 4.
        grid: Indexable so that ``grid[0]`` and ``grid[1]`` are coordinate arrays
            with the same number of elements ``M``; each is flattened.

    Returns:
        ``(M, embed_dim)`` array: the encoding of ``grid[0]`` fills the first half
        of the channels and the encoding of ``grid[1]`` fills the second half.

    Raises:
        ValueError: if ``embed_dim`` is odd (or, via the 1-D helper, if
            ``embed_dim // 2`` is odd).
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")
    half_dim = embed_dim // 2
    # NOTE: get_2d_sincos_pos_embed builds grid with meshgrid(grid_w, grid_h), so
    # there grid[0] carries the width (column) coordinate, not the height.
    encodings = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(encodings, axis=1)  # (M, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode scalar positions with fixed sine/cosine features.

    Args:
        embed_dim: Output feature size per position; must be even.
        pos: Array of positions of any shape; it is flattened to ``(M,)``.

    Returns:
        ``np.ndarray`` of shape ``(M, embed_dim)``. With frequencies
        ``omega_i = 1 / 10000 ** (2 * i / embed_dim)`` for ``i < embed_dim // 2``,
        the first half of each row is ``sin(pos * omega)`` and the second half
        ``cos(pos * omega)``.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be divisible by 2")
    half = embed_dim // 2
    exponents = np.arange(half, dtype=np.float64) / (embed_dim / 2.0)
    freqs = 1.0 / 10000**exponents  # (D/2,)
    # Outer product via broadcasting: (M, 1) * (1, D/2) -> (M, D/2).
    angles = pos.reshape(-1)[:, None] * freqs[None, :]
    return np.concatenate((np.sin(angles), np.cos(angles)), axis=1)  # (M, D)
class Patch1D(nn.Module):
    """Downsample a ``(batch, channels, length)`` signal by ``stride`` along the last axis.

    With ``use_conv=True`` a ``Conv1d`` (kernel size = stride = ``stride``) is used.
    It is initialised so that output channel ``i`` is the mean of input channel
    ``i`` over each window, i.e. it starts out equal to average pooling. Output
    channels with no matching input channel start at zero. Otherwise a
    parameter-free ``AvgPool1d`` is used, which requires ``out_channels == channels``.

    Args:
        channels: Number of input channels.
        use_conv: Use a learnable strided convolution instead of average pooling.
        out_channels: Number of output channels; defaults to ``channels``.
        stride: Window size and stride (the downsampling factor).
        padding: Padding passed to the convolution (unused by the pooling path).
        name: Stored on the module; not otherwise used here.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        stride: int = 2,
        padding: int = 0,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name
        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, stride, stride=stride, padding=padding)
            init.constant_(self.conv.weight, 0.0)
            with torch.no_grad():
                # Weight is (out_channels, in_channels, kernel). Only the shared
                # diagonal exists when the channel counts differ; indexing past
                # in_channels previously raised IndexError for out > in.
                for i in range(min(self.out_channels, self.channels)):
                    self.conv.weight[i, i] = 1 / stride
            init.constant_(self.conv.bias, 0.0)
        else:
            assert self.channels == self.out_channels, "AvgPool1d cannot change the channel count"
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the strided conv / average pool to ``inputs``."""
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
class UnPatch1D(nn.Module):
    """Double the length of a ``(batch, channels, length)`` signal.

    The mode is fixed at construction:

    * ``use_conv_transpose``: a learned ``ConvTranspose1d`` (kernel 4, stride 2, padding 1).
    * ``use_conv``: nearest-neighbour upsampling followed by a kernel-3 ``Conv1d``.
    * neither: nearest-neighbour upsampling only (``self.conv`` is ``None``).

    Args:
        channels: Number of input channels.
        use_conv: Add a ``Conv1d`` after nearest-neighbour upsampling.
        use_conv_transpose: Use a transposed convolution instead (takes precedence).
        out_channels: Output channels for the conv paths; defaults to ``channels``.
        name: Stored on the module; not otherwise used here.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            layer = nn.ConvTranspose1d(channels, self.out_channels, kernel_size=4, stride=2, padding=1)
        elif use_conv:
            layer = nn.Conv1d(self.channels, self.out_channels, kernel_size=3, padding=1)
        else:
            layer = None
        self.conv = layer

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Upsample ``inputs`` by 2 along the last axis."""
        assert inputs.shape[1] == self.channels
        if not self.use_conv_transpose:
            upsampled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
            return self.conv(upsampled) if self.use_conv else upsampled
        return self.conv(inputs)
class Upsampler(nn.Module):
    """Base module that records how much a subclass upsamples in space and time.

    It defines no ``forward`` itself; subclasses implement the resampling.
    """

    def __init__(
        self,
        spatial_upsample_factor: int = 1,
        temporal_upsample_factor: int = 1,
    ):
        super().__init__()
        self.spatial_upsample_factor, self.temporal_upsample_factor = (
            spatial_upsample_factor,
            temporal_upsample_factor,
        )
class TemporalUpsampler3D(Upsampler):
    """Double the frame count (dim 2) of a 5-D tensor while keeping frame 0 as-is.

    For ``T > 1`` frames, frame 0 passes through unchanged and frames ``1..T-1``
    are trilinearly resized to twice their count, giving ``1 + 2 * (T - 1)``
    frames. The last two dims are not resized. Inputs with ``T <= 1`` are
    returned unchanged.
    """

    def __init__(self):
        super().__init__(spatial_upsample_factor=1, temporal_upsample_factor=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[2] <= 1:
            return x
        head = x[:, :, :1]
        tail = F.interpolate(x[:, :, 1:], scale_factor=(2, 1, 1), mode="trilinear")
        return torch.cat((head, tail), dim=2)
def cast_tuple(t, length=1):
    """Return ``t`` unchanged if it is already a tuple, else ``(t,) * length``."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
def divisible_by(num, den):
    """Return whether ``num`` is an exact multiple of ``den`` (``num % den == 0``)."""
    remainder = num % den
    return remainder == 0
def is_odd(n):
    """Return True when ``n`` leaves a non-zero remainder modulo 2."""
    return not (n % 2 == 0)
class CausalConv3d(nn.Conv3d):
    """3D convolution that is causal along the time axis.

    Input is ``(B, C, T, H, W)``. ``forward`` front-pads the time axis with
    ``(t_kernel - 1) * t_dilation`` replicated copies of the first frame, so
    output frame ``t`` only depends on input frames ``<= t``. Spatial padding is
    symmetric and handled by ``nn.Conv3d`` itself.

    Args:
        in_channels: Input channels, as for ``nn.Conv3d``.
        out_channels: Output channels, as for ``nn.Conv3d``.
        kernel_size: int or 3-tuple ``(t, h, w)``.
        stride: int or 3-tuple ``(t, h, w)``.
        padding: Spatial padding only. An ``int`` is used for both H and W;
            ``None`` uses ``ceil(((k - 1) * d + 1 - s) / 2)`` per spatial axis.
            Temporal padding is always causal and ignores this argument.
        dilation: int or 3-tuple ``(t, h, w)``.
        **kwargs: Forwarded to ``nn.Conv3d`` (e.g. ``bias``).

    Raises:
        AssertionError: if kernel_size / stride / dilation is a tuple of the wrong length.
        NotImplementedError: if ``padding`` is neither an ``int`` nor ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,  # int | tuple[int, int, int]
        stride=1,  # int | tuple[int, int, int]
        padding=1,  # int | None  # TODO: change it to 0.
        dilation=1,  # int | tuple[int, int, int]
        **kwargs,
    ):
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."
        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."
        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        t_ks, h_ks, w_ks = kernel_size
        _, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        # All temporal padding goes in front (applied in forward); none behind.
        t_pad = (t_ks - 1) * t_dilation

        # TODO: align with SD
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            # Fail clearly here instead of later on an unbound h_pad / w_pad.
            raise NotImplementedError(
                f"CausalConv3d only supports int or None padding, got {padding!r}."
            )

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),
            **kwargs,
        )
        # Set after nn.Module.__init__ so the module is fully initialised first.
        self.temporal_padding = t_pad

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W). The pad tuple runs last-dim-first, so only the front
        # of T is padded, by repeating the first frame.
        x = F.pad(
            x,
            pad=(0, 0, 0, 0, self.temporal_padding, 0),
            mode="replicate",  # TODO: check if this is necessary
        )
        return super().forward(x)
class PatchEmbed3D(nn.Module):
    """3D (video) patch embedding with fixed 2D sine/cosine position embeddings.

    A ``Conv3d`` with kernel = stride = ``(time_patch_size, patch_size, patch_size)``
    turns a ``(B, C, F, H, W)`` input into non-overlapping spatio-temporal patches.
    Each temporal slice becomes its own token sequence. With ``flatten`` the
    output is ``(B * F', N, embed_dim)``, where ``F' = F // time_patch_size`` and
    ``N = (H // patch_size) * (W // patch_size)``. The same 2D position
    embedding is added to every slice, and it is recomputed on the fly when the
    input's patch grid differs from the one given at construction.

    Args:
        height: Reference input height used to build the cached position table.
        width: Reference input width used to build the cached position table.
        patch_size: Spatial patch size (kernel and stride).
        time_patch_size: Temporal patch size (kernel and stride).
        in_channels: Input channels.
        embed_dim: Output channels per token.
        layer_norm: Apply a non-affine ``LayerNorm`` to the tokens.
        flatten: Flatten the spatial grid into a token axis.
        bias: Whether the projection has a bias.
        interpolation_scale: Passed to ``get_2d_sincos_pos_embed``.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()
        self.flatten = flatten
        self.layer_norm = layer_norm
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=(time_patch_size, patch_size, patch_size),
            stride=(time_patch_size, patch_size, patch_size),
            bias=bias,
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None
        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        # Build the table for the real (rows, cols) patch grid. A square
        # int(sqrt(num_patches)) grid has the wrong token count whenever
        # height != width; for square grids both forms are identical.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        # Fold temporal patches into the batch: (B, C, F', h, w) -> (B*F', C, h, w).
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F', C, h, w) -> (B*F', N, C)
        if self.layer_norm:
            latent = self.norm(latent)
        # NOTE(review): with flatten=False the tensor stays (B*F', C, h, w) and the
        # (1, N, C) pos_embed below will generally not broadcast — confirm unused.

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed
        return (latent + pos_embed).to(latent.dtype)
class PatchEmbedF3D(nn.Module):
    """"Fake" 3D patch embedding: factorised spatial-then-temporal patchify.

    Each frame of a ``(B, C, F, H, W)`` input is patchified by a ``Conv2d``
    (kernel = stride = ``patch_size``). Then, at every spatial patch location,
    the frame axis is downsampled by ``Patch1D`` with stride ``patch_size``, so
    the temporal patch size equals the spatial one. With ``flatten`` the output
    is ``(B * F', N, embed_dim)``, where ``F' = F // patch_size`` and
    ``N = (H // patch_size) * (W // patch_size)``. The same 2D position
    embedding is added to every temporal slice.

    Args:
        height: Reference input height used to build the cached position table.
        width: Reference input width used to build the cached position table.
        patch_size: Spatial and temporal patch size.
        in_channels: Input channels.
        embed_dim: Output channels per token.
        layer_norm: Apply a non-affine ``LayerNorm`` to the tokens.
        flatten: Flatten the spatial grid into a token axis.
        bias: Whether the spatial projection has a bias.
        interpolation_scale: Passed to ``get_2d_sincos_pos_embed``.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()
        self.flatten = flatten
        self.layer_norm = layer_norm
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.proj_t = Patch1D(
            embed_dim, True, stride=patch_size
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None
        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        # Build the table for the real (rows, cols) patch grid; a square
        # int(sqrt(num_patches)) grid is wrong whenever height != width.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
        num_frames = latent.shape[2]

        # Spatial patchify, one frame at a time.
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        latent = self.proj(latent)
        # Take the patch grid from the conv output. The previous hard-coded
        # (h // 2, w // 2) was only right for patch_size == 2 and otherwise
        # failed or silently mixed tokens across batch items.
        grid_h, grid_w = latent.shape[-2], latent.shape[-1]

        # Temporal patchify, independently at every spatial patch location.
        latent = rearrange(latent, "(b f) c h w -> (b h w) c f", f=num_frames)
        latent = self.proj_t(latent)
        latent = rearrange(latent, "(b h w) c f -> (b f) c h w", h=grid_h, w=grid_w)

        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F', C, h, w) -> (B*F', N, C)
        if self.layer_norm:
            latent = self.norm(latent)
        # NOTE(review): with flatten=False the (1, N, C) pos_embed below will
        # generally not broadcast against (B*F', C, h, w) — confirm unused.

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed
        return (latent + pos_embed).to(latent.dtype)
class CasualPatchEmbed3D(nn.Module):
    """Causal 3D (video) patch embedding with fixed 2D sine/cosine position embeddings.

    (The class name keeps its historical "Casual" spelling for compatibility.)

    Like ``PatchEmbed3D``, but the projection is a ``CausalConv3d``. It front-pads
    time with ``time_patch_size - 1`` copies of the first frame, so a
    ``(B, C, F, H, W)`` input yields ``F' = (F - 1) // time_patch_size + 1``
    temporal patches and the first frame forms a patch on its own. With
    ``flatten`` the output is ``(B * F', N, embed_dim)`` with
    ``N = (H // patch_size) * (W // patch_size)``.

    Args:
        height: Reference input height used to build the cached position table.
        width: Reference input width used to build the cached position table.
        patch_size: Spatial patch size (kernel and stride).
        time_patch_size: Temporal patch size (kernel and stride).
        in_channels: Input channels.
        embed_dim: Output channels per token.
        layer_norm: Apply a non-affine ``LayerNorm`` to the tokens.
        flatten: Flatten the spatial grid into a token axis.
        bias: Whether the projection has a bias.
        interpolation_scale: Passed to ``get_2d_sincos_pos_embed``.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()
        self.flatten = flatten
        self.layer_norm = layer_norm
        # padding=None makes CausalConv3d derive zero spatial padding for kernel == stride.
        self.proj = CausalConv3d(
            in_channels,
            embed_dim,
            kernel_size=(time_patch_size, patch_size, patch_size),
            stride=(time_patch_size, patch_size, patch_size),
            bias=bias,
            padding=None,
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None
        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        # Build the table for the real (rows, cols) patch grid; a square
        # int(sqrt(num_patches)) grid is wrong whenever height != width.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        # Fold temporal patches into the batch: (B, C, F', h, w) -> (B*F', C, h, w).
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F', C, h, w) -> (B*F', N, C)
        if self.layer_norm:
            latent = self.norm(latent)
        # NOTE(review): with flatten=False the (1, N, C) pos_embed below will
        # generally not broadcast against (B*F', C, h, w) — confirm unused.

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed
        return (latent + pos_embed).to(latent.dtype)