bubbliiiing
update v3
e262715
import math
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange
from torch import nn
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build a 2-D sine/cosine positional embedding table as a NumPy array.

    grid_size: an int (used for both sides) or a (height, width) pair.
    Coordinates along each axis are `arange(n) / (n / base_size) / interpolation_scale`.

    return: pos_embed of shape [grid_h*grid_w, embed_dim]; when `cls_token` is set and
    `extra_tokens > 0`, `extra_tokens` all-zero rows are prepended.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    rows, cols = grid_size[0], grid_size[1]

    def _axis_coords(count):
        # Rescale integer indices so that `count` samples cover `base_size` units.
        return np.arange(count, dtype=np.float32) / (count / base_size) / interpolation_scale

    coords_h = _axis_coords(rows)
    coords_w = _axis_coords(cols)

    # meshgrid is fed the width axis first, so mesh[0] carries the w coordinates.
    mesh = np.stack(np.meshgrid(coords_w, coords_h), axis=0)
    mesh = mesh.reshape([2, 1, cols, rows])

    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if cls_token and extra_tokens > 0:
        zero_rows = np.zeros([extra_tokens, embed_dim])
        embedding = np.concatenate([zero_rows, embedding], axis=0)
    return embedding
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a two-channel coordinate grid into a sine/cosine embedding.

    Each of grid[0] and grid[1] is encoded with `embed_dim // 2` features and the two
    halves are concatenated along the feature axis, giving an (M, embed_dim) array where
    M is the number of positions in each grid channel.

    Raises ValueError when `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # First half of the features comes from grid[0], second half from grid[1].
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)  # (M, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions with sine/cosine features.

    embed_dim: output dimension for each position (must be even).
    pos: array of positions to be encoded; it is flattened to size (M,).
    out: (M, D) array laid out as [sin | cos], each half D/2 wide.

    Raises ValueError when `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # Frequencies 1 / 10000^(k / (D/2)) for k = 0 .. D/2 - 1.
    exponents = np.arange(half_dim, dtype=np.float64) / (embed_dim / 2.0)
    frequencies = 1.0 / 10000**exponents  # (D/2,)

    positions = pos.reshape(-1)  # (M,)
    # Outer product of positions and frequencies.
    angles = positions[:, None] * frequencies[None, :]  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
class Patch1D(nn.Module):
    """1-D downsampling layer over the last axis of a (batch, channels, length) tensor.

    With ``use_conv=True`` a strided ``nn.Conv1d`` (kernel size == stride) is used and
    initialised so that, for each channel index shared by input and output, it
    averages the window of that channel (weights ``1 / stride``, zero bias). Any
    additional output channels start at zero.
    With ``use_conv=False`` an ``nn.AvgPool1d`` is used instead.

    Args:
        channels: number of input channels.
        use_conv: use the learnable convolution instead of average pooling.
        out_channels: number of output channels; defaults to ``channels``.
            Must equal ``channels`` when ``use_conv`` is False.
        stride: downsampling factor (also the kernel size).
        padding: padding passed to the convolution; not used by the pooling branch.
        name: stored on the module as ``self.name``.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        stride: int = 2,
        padding: int = 0,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, stride, stride=stride, padding=padding)
            # weight has shape (out_channels, in_channels, kernel); only the
            # leading square part has a diagonal, so bound the index by the
            # smaller of the two channel counts (out_channels > channels used
            # to index past the in_channels axis).
            diag = torch.arange(min(self.channels, self.out_channels))
            with torch.no_grad():
                self.conv.weight.zero_()
                self.conv.weight[diag, diag] = 1 / stride
                self.conv.bias.zero_()
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Downsample ``inputs``; its dim 1 must equal ``self.channels``."""
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
class UnPatch1D(nn.Module):
    """1-D upsampling layer that doubles the last axis of a (batch, channels, length) tensor.

    Three modes, selected at construction time:
      * ``use_conv_transpose``: a single ``nn.ConvTranspose1d`` (kernel 4, stride 2, padding 1);
      * ``use_conv``: nearest-neighbour interpolation followed by a kernel-3 ``nn.Conv1d``;
      * neither: nearest-neighbour interpolation only (``self.conv`` is ``None``).

    Args:
        channels: number of input channels.
        use_conv: apply a convolution after interpolation.
        use_conv_transpose: upsample with a transposed convolution (takes precedence).
        out_channels: number of output channels; defaults to ``channels``.
        name: stored on the module as ``self.name``.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        if use_conv_transpose:
            layer = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            layer = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
        else:
            layer = None
        self.conv = layer

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Upsample ``inputs``; its dim 1 must equal ``self.channels``."""
        assert inputs.shape[1] == self.channels

        if self.use_conv_transpose:
            # The transposed convolution performs the upsampling by itself.
            return self.conv(inputs)

        upsampled = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled
class Upsampler(nn.Module):
    """Base module that only records spatial and temporal upsampling factors.

    It defines no ``forward`` of its own.
    """

    def __init__(
        self,
        spatial_upsample_factor: int = 1,
        temporal_upsample_factor: int = 1,
    ):
        super().__init__()
        # Plain bookkeeping attributes; nothing is registered as a parameter or buffer.
        self.spatial_upsample_factor, self.temporal_upsample_factor = (
            spatial_upsample_factor,
            temporal_upsample_factor,
        )
class TemporalUpsampler3D(Upsampler):
    """Doubles the frame axis (dim 2) of a 5-D tensor while keeping the first frame as is.

    Registers itself with a spatial factor of 1 and a temporal factor of 2.
    """

    def __init__(self):
        super().__init__(spatial_upsample_factor=1, temporal_upsample_factor=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with every frame after the first interpolated 2x along dim 2.

        Inputs with a single frame are returned unchanged; otherwise T frames
        become 1 + 2 * (T - 1) frames.
        """
        if x.shape[2] <= 1:
            return x

        head = x[:, :, :1]
        tail = x[:, :, 1:]
        # Only the frame axis is stretched; the last two axes keep their size.
        tail = F.interpolate(tail, scale_factor=(2, 1, 1), mode="trilinear")
        return torch.cat([head, tail], dim=2)
def cast_tuple(t, length = 1):
    """Return ``t`` unchanged if it is a tuple, otherwise a tuple repeating ``t`` ``length`` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
def divisible_by(num, den):
    """Return whether ``num`` leaves no remainder when divided by ``den``."""
    remainder = num % den
    return remainder == 0
def is_odd(n):
    """Return whether ``n`` is not evenly divisible by 2."""
    remainder = n % 2
    return not (remainder == 0)
class CausalConv3d(nn.Conv3d):
    """``nn.Conv3d`` whose temporal padding is applied only before the first frame.

    The underlying convolution is built with zero padding on the temporal axis;
    ``forward`` instead prepends ``(t_kernel - 1) * t_dilation`` replicated copies
    of the first frame, so an output frame never depends on later input frames.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int or 3-tuple ``(t, h, w)``.
        stride: int or 3-tuple ``(t, h, w)``.
        padding: spatial padding. An int is used for both height and width;
            ``None`` derives it from kernel size, dilation and stride.
        dilation: int or 3-tuple ``(t, h, w)``.
        **kwargs: forwarded to ``nn.Conv3d``.

    Raises:
        AssertionError: if ``kernel_size``, ``stride`` or ``dilation`` is a tuple
            whose length is not 3.
        NotImplementedError: if ``padding`` is neither ``None`` nor an int.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,  # : int | tuple[int, int, int],
        stride=1,  # : int | tuple[int, int, int] = 1,
        padding=1,  # : int | tuple[int, int, int],  # TODO: change it to 0.
        dilation=1,  # : int | tuple[int, int, int] = 1,
        **kwargs,
    ):
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."

        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."

        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        t_ks, h_ks, w_ks = kernel_size
        _, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        # Number of frames to prepend so the temporal receptive field only looks backwards.
        t_pad = (t_ks - 1) * t_dilation

        # TODO: align with SD
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            # Previously `assert NotImplementedError`, which is always true and let
            # execution fall through to an UnboundLocalError on h_pad below.
            raise NotImplementedError(
                f"padding must be None or an int, got {padding!r} instead."
            )

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),
            **kwargs,
        )
        # Set after the base initialiser so the module is fully constructed first.
        self.temporal_padding = t_pad

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Left-pad the temporal axis by replication, then apply the convolution."""
        # x: (B, C, T, H, W); F.pad lists pads from the last axis backwards,
        # so the final pair (temporal_padding, 0) pads only the front of T.
        x = F.pad(
            x,
            pad=(0, 0, 0, 0, self.temporal_padding, 0),
            mode="replicate",  # TODO: check if this is necessary
        )
        return super().forward(x)
class PatchEmbed3D(nn.Module):
    """3D Image to Patch Embedding.

    A strided ``nn.Conv3d`` cuts a ``(B, C, F, H, W)`` input into non-overlapping
    ``(time_patch_size, patch_size, patch_size)`` patches and projects each one to
    ``embed_dim`` features. The resulting frame axis is folded into the batch
    axis and a 2-D sine/cosine positional embedding is added.

    Args:
        height, width: reference input size used to precompute the positional table.
        patch_size: spatial patch size (kernel and stride along H and W).
        time_patch_size: temporal patch size (kernel and stride along F).
        in_channels: number of input channels.
        embed_dim: number of output features per patch.
        layer_norm: apply a non-affine ``nn.LayerNorm`` after the projection.
        flatten: flatten the spatial grid into a token axis.
        bias: whether the projection has a bias term.
        interpolation_scale: scale passed to the positional-embedding helper.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()
        self.flatten = flatten
        self.layer_norm = layer_norm

        patch_shape = (time_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_shape, stride=patch_shape, bias=bias)
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale

        # Build the cached table on the real (rows, cols) patch grid. A square grid
        # of side int(sqrt(num_patches)) has the right number of entries only when
        # height == width; for square inputs both forms are identical.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        """Embed ``latent`` of shape (B, C, F, H, W).

        With ``flatten=True`` the result has shape (B * F', N, embed_dim), where F'
        is the frame count after the projection and N the number of spatial patches.
        NOTE(review): the positional table is (1, N, embed_dim), so the final
        addition presumably only lines up when ``flatten`` is True - confirm.
        """
        height = latent.shape[-2] // self.patch_size
        width = latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        # Treat every output frame as an independent sample.
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCFHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            # Input grid differs from the cached one: rebuild the table for this size.
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
class PatchEmbedF3D(nn.Module):
    """Fake 3D Image to Patch Embedding.

    Instead of a 3-D convolution, each frame of a ``(B, C, F, H, W)`` input is
    patchified with an ``nn.Conv2d`` and the frame axis is then downsampled per
    spatial location with a ``Patch1D`` convolution. The remaining frames are
    folded into the batch axis and a 2-D sine/cosine positional embedding is added.

    Args:
        height, width: reference input size used to precompute the positional table.
        patch_size: spatial patch size; also used as the stride of the temporal
            ``Patch1D`` layer.
        in_channels: number of input channels.
        embed_dim: number of output features per patch.
        layer_norm: apply a non-affine ``nn.LayerNorm`` after the projection.
        flatten: flatten the spatial grid into a token axis.
        bias: whether the 2-D projection has a bias term.
        interpolation_scale: scale passed to the positional-embedding helper.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.proj_t = Patch1D(
            embed_dim, True, stride=patch_size
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale

        # Build the cached table on the real (rows, cols) patch grid. A square grid
        # of side int(sqrt(num_patches)) has the right number of entries only when
        # height == width; for square inputs both forms are identical.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        """Embed ``latent`` of shape (B, C, F, H, W).

        With ``flatten=True`` the result has shape (B * F', N, embed_dim), where F'
        is the frame count after the temporal projection and N the number of
        spatial patches.
        NOTE(review): the positional table is (1, N, embed_dim), so the final
        addition presumably only lines up when ``flatten`` is True - confirm.
        """
        # Patch-grid size after the spatial projection (kernel == stride == patch_size).
        height = latent.shape[-2] // self.patch_size
        width = latent.shape[-1] // self.patch_size
        num_frames = latent.shape[2]

        # Spatial patchify, one frame at a time.
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        latent = self.proj(latent)

        # Temporal downsampling, one spatial location at a time.
        latent = rearrange(latent, "(b f) c h w -> (b h w) c f", f=num_frames)
        latent = self.proj_t(latent)
        # Use the actual patch-grid size here; the former `h // 2, w // 2` was only
        # correct for patch_size == 2.
        latent = rearrange(latent, "(b h w) c f -> (b f) c h w", h=height, w=width)

        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCFHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            # Input grid differs from the cached one: rebuild the table for this size.
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
class CasualPatchEmbed3D(nn.Module):
    """3D Image to Patch Embedding using a ``CausalConv3d`` projection.

    Same pipeline as a plain 3-D patch embedding, except that the projection is a
    ``CausalConv3d`` built with ``padding=None``. The input ``(B, C, F, H, W)`` is
    projected, the resulting frame axis is folded into the batch axis, and a 2-D
    sine/cosine positional embedding is added.

    Args:
        height, width: reference input size used to precompute the positional table.
        patch_size: spatial patch size (kernel and stride along H and W).
        time_patch_size: temporal patch size (kernel and stride along F).
        in_channels: number of input channels.
        embed_dim: number of output features per patch.
        layer_norm: apply a non-affine ``nn.LayerNorm`` after the projection.
        flatten: flatten the spatial grid into a token axis.
        bias: whether the projection has a bias term.
        interpolation_scale: scale passed to the positional-embedding helper.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()
        self.flatten = flatten
        self.layer_norm = layer_norm

        patch_shape = (time_patch_size, patch_size, patch_size)
        self.proj = CausalConv3d(
            in_channels, embed_dim, kernel_size=patch_shape, stride=patch_shape, bias=bias, padding=None
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale

        # Build the cached table on the real (rows, cols) patch grid. A square grid
        # of side int(sqrt(num_patches)) has the right number of entries only when
        # height == width; for square inputs both forms are identical.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        """Embed ``latent`` of shape (B, C, F, H, W).

        With ``flatten=True`` the result has shape (B * F', N, embed_dim), where F'
        is the frame count after the projection and N the number of spatial patches.
        NOTE(review): the positional table is (1, N, embed_dim), so the final
        addition presumably only lines up when ``flatten`` is True - confirm.
        """
        height = latent.shape[-2] // self.patch_size
        width = latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        # Treat every output frame as an independent sample.
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCFHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            # Input grid differs from the cached one: rebuild the table for this size.
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)