""" |
|
https://github.com/lucidrains/make-a-video-pytorch |
|
""" |
|
|
|
import math |
|
import functools |
|
from operator import mul |
|
|
|
import torch |
|
from torch import nn, einsum |
|
import torch.nn.functional as F |
|
|
|
from einops import rearrange, repeat, pack, unpack |
|
from einops.layers.torch import Rearrange |
|
|
|
from .modules_conv import avg_pool_nd, zero_module, normalization, conv_nd |
|
|
|
|
|
|
|
|
|
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def mul_reduce(tup):
    return functools.reduce(mul, tup)


def divisible_by(numer, denom):
    return (numer % denom) == 0


mlist = nn.ModuleList


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.theta = theta
        self.dim = dim

    def forward(self, x):
        dtype, device = x.dtype, x.device
        assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'

        half_dim = self.dim // 2
        emb = math.log(self.theta) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)


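# Usage sketch (illustrative): embed a batch of scalar positions into `dim`-dimensional
# sinusoidal features, e.g. for 8 frame indices:
#
#   pos_emb = SinusoidalPosEmb(dim = 64)
#   t = torch.arange(8).float()          # (8,)
#   emb = pos_emb(t)                     # (8, 64)

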
class ChanLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim, 1, 1, 1))

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        x = (x - mean) * var.clamp(min = eps).rsqrt()
        dtype = self.g.dtype
        return x.to(dtype) * self.g


def shift_token(t):
    # Split the channels in half and delay the second half by one frame
    # (zero-padded at the start), then re-concatenate, so each frame also
    # sees features from the previous frame.
    t, t_shift = t.chunk(2, dim = 1)
    t_shift = F.pad(t_shift, (0, 0, 0, 0, 1, -1), value = 0.)
    return torch.cat((t, t_shift), dim = 1)


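# Shape sketch (illustrative): for x of shape (b, c, t, h, w), `shift_token` keeps the
# first c//2 channels as-is and delays the remaining channels by one frame:
#
#   x = torch.randn(1, 4, 3, 2, 2)
#   y = shift_token(x)                   # same shape; y[:, 2:, 0] is all zeros

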
class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # normalize over the feature (last) dimension, matching the shape of the gain `g`
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) * var.clamp(min = eps).rsqrt() * self.g


class GEGLU(nn.Module):
    def forward(self, x):
        x = x.float()
        x, gate = x.chunk(2, dim = 1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4):
        super().__init__()

        inner_dim = int(dim * mult * 2 / 3)
        self.proj_in = nn.Sequential(
            nn.Conv3d(dim, inner_dim * 2, 1, bias = False),
            GEGLU()
        )

        self.proj_out = nn.Sequential(
            ChanLayerNorm(inner_dim),
            nn.Conv3d(inner_dim, dim, 1, bias = False)
        )

    def forward(self, x, enable_time=True):
        x = self.proj_in(x)
        if enable_time:
            x = shift_token(x)
        return self.proj_out(x)


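# Usage sketch (illustrative): a gated (GEGLU) feed-forward over the channels of a video
# feature map, with the temporal token shift applied when time is enabled:
#
#   ff = FeedForward(dim = 64, mult = 4)
#   video = torch.randn(2, 64, 8, 16, 16)    # (b, c, t, h, w)
#   out = ff(video, enable_time = True)      # (2, 64, 8, 16, 16)

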
class ContinuousPositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883 """

    def __init__(
        self,
        *,
        dim,
        heads,
        num_dims = 1,
        layers = 2,
        log_dist = True,
        cache_rel_pos = False
    ):
        super().__init__()
        self.num_dims = num_dims
        self.log_dist = log_dist

        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))

        self.net.append(nn.Linear(dim, heads))

        self.cache_rel_pos = cache_rel_pos
        self.register_buffer('rel_pos', None, persistent = False)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def forward(self, *dimensions):
        device = self.device

        if not exists(self.rel_pos) or not self.cache_rel_pos:
            positions = [torch.arange(d, device = device) for d in dimensions]
            grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'))
            grid = rearrange(grid, 'c ... -> (...) c')
            rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')

            if self.log_dist:
                rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)

            self.register_buffer('rel_pos', rel_pos, persistent = False)

        rel_pos = self.rel_pos.to(self.dtype)

        for layer in self.net:
            rel_pos = layer(rel_pos)

        return rearrange(rel_pos, 'i j h -> h i j')


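# Usage sketch (illustrative): produce a per-head relative-position bias over the
# temporal axis, to be added to attention logits of shape (b, heads, t, t):
#
#   cpb = ContinuousPositionBias(dim = 32, heads = 4, num_dims = 1)
#   bias = cpb(8)                        # (4, 8, 8) == (heads, t, t) for t = 8 frames

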
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads
        self.norm = LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # zero-init the output projection so the attention branch contributes
        # nothing at initialization (it is used inside residual connections)
        nn.init.zeros_(self.to_out.weight.data)

        # learned embeddings for up to 30 frame positions and 30 integer frame rates
        self.pos_embeds = nn.Parameter(torch.randn([1, 30, dim]))
        self.frame_rate_embeds = nn.Parameter(torch.randn([1, 30, dim]))

    def forward(
        self,
        x,
        context = None,
        rel_pos_bias = None,
        framerate = None,
    ):
        if framerate is not None:
            # add per-frame position embeddings and a single frame-rate embedding
            # (framerate is expected to be an integer in [1, 30])
            x = x + self.pos_embeds[:, :x.shape[1]].repeat(x.shape[0], 1, 1)
            x = x + self.frame_rate_embeds[:, framerate - 1:framerate].repeat(x.shape[0], x.shape[1], 1)

        if context is None:
            context = x

        x = self.norm(x)
        context = self.norm(context)

        q, k, v = self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


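# Usage sketch (illustrative): temporal self-attention over a sequence of frame tokens,
# with a relative-position bias and frame-rate conditioning:
#
#   attn = Attention(dim = 64, dim_head = 32, heads = 4)
#   cpb = ContinuousPositionBias(dim = 32, heads = 4, num_dims = 1)
#   tokens = torch.randn(2, 8, 64)                             # (batch, frames, dim)
#   out = attn(tokens, rel_pos_bias = cpb(8), framerate = 4)   # (2, 8, 64)

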
class PseudoConv3d(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)

        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
        self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2) if kernel_size > 1 else None

        if exists(self.temporal_conv):
            # dirac weight + zero bias makes the temporal convolution an identity at
            # initialization, so pretrained spatial weights are preserved when the
            # temporal layer is added
            nn.init.dirac_(self.temporal_conv.weight.data)
            nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(
        self,
        x,
        enable_time = True
    ):
        b, c, *_, h, w = x.shape

        is_video = x.ndim == 5
        enable_time &= is_video

        # spatial convolution applied frame by frame
        if is_video:
            x = rearrange(x, 'b c t h w -> (b t) c h w')

        x = self.spatial_conv(x)

        if is_video:
            x = rearrange(x, '(b t) c h w -> b c t h w', b = b)

        if not enable_time or not exists(self.temporal_conv):
            return x

        # temporal convolution applied independently at every spatial location
        x = rearrange(x, 'b c t h w -> (b h w) c t')

        x = self.temporal_conv(x)

        x = rearrange(x, '(b h w) c t -> b c t h w', h = h, w = w)

        return x


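# Usage sketch (illustrative): a factorized "2D spatial + 1D temporal" convolution that
# also accepts plain image batches, in which case the temporal part is skipped:
#
#   conv = PseudoConv3d(dim = 32, kernel_size = 3)
#   video = torch.randn(1, 32, 8, 16, 16)    # (b, c, t, h, w)
#   images = torch.randn(4, 32, 16, 16)      # (b, c, h, w)
#   conv(video).shape                        # (1, 32, 8, 16, 16)
#   conv(images).shape                       # (4, 32, 16, 16)

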
def frame_shift(x, shift_num=8):
    # Split the channels into `shift_num` groups and delay group i by i frames
    # (zero-padded at the start), so each time step mixes channel groups coming
    # from several previous frames.
    x = list(x.chunk(shift_num, 1))
    for i in range(shift_num):
        if i > 0:
            shifted = torch.cat([torch.zeros_like(x[i][:, :, :i]), x[i][:, :, :-i]], 2)
        else:
            shifted = x[i]
        x[i] = shifted
    return torch.cat(x, 1)


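# Shape sketch (illustrative): with 8 channel groups, group 0 is unshifted, group 1 is
# delayed by one frame, group 2 by two frames, and so on:
#
#   x = torch.randn(1, 64, 8, 4, 4)      # (b, c, t, h, w)
#   y = frame_shift(x, shift_num = 8)    # same shape; y[:, 8:16, 0] is all zeros

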
class ResBlockFrameShift(nn.Module):
    """
    A residual block that applies a frame shift along the temporal axis and can
    optionally change the number of channels.

    :param channels: the number of input channels.
    :param dropout: stored but not applied in this block.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: stored, but gradient checkpointing is not applied here.
    :param up: accepted for interface compatibility; unused.
    :param down: accepted for interface compatibility; unused.
    """

    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint

        self.out_layers = nn.Sequential(
            normalization(self.channels),
            nn.SiLU(),
            zero_module(
                conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x):
        """
        Apply the block to a video Tensor.

        :param x: an [N x C x T x H x W] Tensor of features.
        :return: an [N x C_out x T x H x W] Tensor of outputs.
        """
        num_frames = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')

        h = self.out_layers(x)

        h = rearrange(h, '(b t) c h w -> b c t h w', t=num_frames)
        h = frame_shift(h)
        h = rearrange(h, 'b c t h w -> (b t) c h w')

        out = self.skip_connection(x) + h
        out = rearrange(out, '(b t) c h w -> b c t h w', t=num_frames)
        return out


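# Usage sketch (illustrative): the main branch ends in a zero-initialized convolution
# (assuming `zero_module` zeroes the conv weights, as in standard diffusion codebases),
# so at initialization the block reduces to an identity mapping over the video features:
#
#   block = ResBlockFrameShift(channels = 64, dropout = 0.0)
#   video = torch.randn(2, 64, 8, 16, 16)     # (b, c, t, h, w)
#   out = block(video)                        # (2, 64, 8, 16, 16); equals `video` at init

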
class ResBlockVideo(nn.Module):
    """
    A residual block, applied frame-wise to video features, that can optionally
    change the number of channels.

    :param channels: the number of input channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: stored but not used in this block.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: stored, but gradient checkpointing is not applied here.
    :param up: accepted for interface compatibility; unused.
    :param down: accepted for interface compatibility; unused.
    """

    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x):
        """
        Apply the block to a video Tensor.

        :param x: an [N x C x T x H x W] Tensor of features.
        :return: an [N x C_out x T x H x W] Tensor of outputs.
        """
        num_frames = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')

        h = x
        h = self.in_layers(h)
        h = self.out_layers(h)

        out = self.skip_connection(x) + h
        out = rearrange(out, '(b t) c h w -> b c t h w', t=num_frames)
        return out


class Downsample3D(nn.Module):
    """
    A convolutional projection layer. Despite its name, the stride is fixed to 1
    below, so this layer changes the number of channels without spatially
    downsampling its input.

    :param channels: channels in the inputs.
    :param use_conv: a bool determining if a convolution is applied; if False,
        average pooling is used and out_channels must equal channels.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param out_channels: if specified, the number of output channels.
    """

    def __init__(self, channels, use_conv, dims=2, stride=None, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # note: the `stride` argument is overridden; the op always uses stride 1
        stride = 1
        if use_conv:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


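# Usage sketch (illustrative): in this file the layer projects a 4-channel input
# (the `x_0` conditioning latent used by SpatioTemporalAttention below) up to the
# feature dimension:
#
#   proj = Downsample3D(4, True, 2, out_channels = 64)
#   x0 = torch.randn(2, 4, 32, 32)       # (b, c, h, w)
#   feat = proj(x0)                      # (2, 64, 32, 32)

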
class SpatioTemporalAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        use_resnet = False,
        use_frame_shift = True,
        use_context_att = False,
        use_temp_att = True,
        use_context = False,
    ):
        super().__init__()
        self.use_resnet = use_resnet
        self.use_frame_shift = use_frame_shift
        self.use_context_att = use_context_att
        self.use_temp_att = use_temp_att

        if use_resnet:
            self.resblock = ResBlockVideo(dim, dropout=0, dims=2)
        if use_frame_shift:
            self.frameshiftblock = ResBlockFrameShift(dim, dropout=0, dims=2)
        if use_context_att:
            # cross-attention of the frame tokens against a projected x_0 conditioning input
            self.downsample_x0 = Downsample3D(4, True, 2, out_channels=dim)
            self.temporal_attn_x0 = Attention(dim = dim, dim_head = dim_head, heads = heads)

        if use_temp_att:
            self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
            self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1)

        self.ff = FeedForward(dim = dim, mult = 4)

    def forward(
        self,
        x,
        x_0 = None,
        enable_time = True,
        framerate = 4,
        is_video = False,
    ):
        x_ndim = x.ndim
        is_video = x_ndim == 5 or is_video
        enable_time &= is_video

        if enable_time:
            img_size = x.shape[-1]
            if self.use_temp_att:
                if x_ndim == 5:
                    b, c, *_, h, w = x.shape
                    # fold the spatial dimensions into the batch so attention runs over time
                    x = rearrange(x, 'b c t h w -> (b h w) t c')
                time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])

            # note: the x_0 cross-attention path reuses time_rel_pos_bias and the
            # (b h w) t c layout, so it assumes use_temp_att is also enabled
            if self.use_context_att and x_0 is not None:
                x_0_img_size = x_0.shape[-1]
                # pool x_0 down to x's spatial resolution, then project it to the feature dim
                kernel_size = x_0_img_size // img_size
                x_0 = F.avg_pool2d(x_0, kernel_size)
                x_0 = self.downsample_x0(x_0).unsqueeze(2)
                if x_ndim == 5:
                    x_0 = rearrange(x_0, 'b c t h w -> (b h w) t c')
                x = self.temporal_attn_x0(x, context=x_0, rel_pos_bias = time_rel_pos_bias, framerate = framerate) + x

            if self.use_temp_att:
                x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias, framerate = framerate) + x
                if x_ndim == 5:
                    x = rearrange(x, '(b h w) t c -> b c t h w', w = w, h = h)

            x = self.ff(x, enable_time=enable_time) + x

        # the frame-shift and resnet branches expect 5-D (b, c, t, h, w) video input
        if self.use_frame_shift:
            x = self.frameshiftblock(x)

        if self.use_resnet:
            x = self.resblock(x)
        return x
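

# Usage sketch (illustrative): the full spatio-temporal block with its default settings,
# combining temporal self-attention (with continuous relative-position bias and
# frame-rate conditioning), a token-shift feed-forward, and a frame-shift residual block:
#
#   block = SpatioTemporalAttention(dim = 64, dim_head = 32, heads = 4)
#   video = torch.randn(1, 64, 8, 16, 16)                     # (b, c, t, h, w)
#   out = block(video, enable_time = True, framerate = 4)     # (1, 64, 8, 16, 16)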