dmolino's picture
Upload 225 files
9a7fe1f verified
raw
history blame
17.4 kB
"""
https://github.com/lucidrains/make-a-video-pytorch
"""
import math
import functools
from operator import mul
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
from .modules_conv import avg_pool_nd, zero_module, normalization, conv_nd
# helper functions
def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d
def mul_reduce(tup):
    """Return the product of every item in ``tup``.

    No initial value is supplied to the reduction, so an empty iterable
    raises ``TypeError``.
    """
    return functools.reduce(lambda running, factor: running * factor, tup)
def divisible_by(numer, denom):
    """Return whether ``numer`` leaves no remainder when divided by ``denom``."""
    remainder = numer % denom
    return remainder == 0
# short module-level alias for ``nn.ModuleList``
mlist = nn.ModuleList
# for time conditioning
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions / timesteps.

    Every input value is multiplied by ``dim // 2`` geometrically spaced
    frequencies (starting at ``1`` and decaying towards ``1 / theta``); the
    sines and cosines of those products are concatenated on a new trailing
    axis. The output therefore has ``2 * (dim // 2)`` features, which equals
    ``dim`` only when ``dim`` is even.

    :param dim: requested embedding width.
    :param theta: base of the geometric frequency progression.
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.theta = theta
        self.dim = dim

    def forward(self, x):
        """
        Embed a tensor of positions.

        :param x: a ``torch.float`` tensor of positions (any shape; a 1-D
                  tensor of length ``i`` gives an ``[i x features]`` output).
        :return: a tensor of shape ``[*x.shape x 2 * (dim // 2)]``.
        :raises AssertionError: if ``x`` is not ``torch.float``.
        """
        dtype, device = x.dtype, x.device
        assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type'

        half_dim = self.dim // 2
        # With a single frequency (dim of 2 or 3) ``half_dim - 1`` is zero and
        # the division used to raise ZeroDivisionError; the only exponent is 0
        # in that case, so any non-zero denominator gives the same frequency.
        step = math.log(self.theta) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -step)

        # broadcast each position against every frequency
        angles = x.unsqueeze(-1) * freqs
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
class ChanLayerNorm(nn.Module):
    """Normalization across dim 1 with a learned per-channel gain.

    The gain ``g`` has shape ``(dim, 1, 1, 1)`` and starts at one. There is no
    learned bias.
    """

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim, 1, 1, 1))

    def forward(self, x):
        # a looser variance floor is used for anything that is not float32
        if x.dtype == torch.float32:
            eps = 1e-5
        else:
            eps = 1e-3

        centered = x - torch.mean(x, dim = 1, keepdim = True)
        variance = torch.var(x, dim = 1, unbiased = False, keepdim = True)

        # the variance is clamped from below (rather than offset by eps)
        normed = centered * variance.clamp(min = eps).rsqrt()

        gain = self.g
        return normed.to(gain.dtype) * gain
def shift_token(t):
    """Delay the second half of the dim-1 entries by one step along the third-from-last axis.

    ``t`` is split into two chunks along dim 1. The first chunk is returned
    untouched; the second gets one zero step prepended and its final step
    dropped, so its length is unchanged. The chunks are then re-joined on dim 1.
    """
    kept, moved = torch.chunk(t, 2, dim = 1)
    # pad pairs run from the last axis backwards: (0, 0), (0, 0), then (1, -1)
    moved = F.pad(moved, (0, 0, 0, 0, 1, -1), value = 0.)
    return torch.cat([kept, moved], dim = 1)
class LayerNorm(nn.Module):
    """Bias-free normalization with a learned gain of shape ``(dim,)``.

    NOTE(review): mean and variance are taken over dim 1, while ``g``
    broadcasts against the last axis. For a ``(batch, seq, dim)`` input this
    normalizes across the sequence axis rather than the feature axis —
    confirm this is intended before changing it, since trained weights may
    depend on it.
    """

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # a looser variance floor is used for anything that is not float32
        if x.dtype == torch.float32:
            eps = 1e-5
        else:
            eps = 1e-3

        centered = x - torch.mean(x, dim = 1, keepdim = True)
        variance = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        inv_std = variance.clamp(min = eps).rsqrt()
        return centered * inv_std * self.g
# feedforward
class GEGLU(nn.Module):
    """GELU-gated linear unit over dim 1.

    The input is cast to float32 and split in two along dim 1; the first half
    is multiplied by the GELU of the second, halving the size of dim 1.
    """

    def forward(self, x):
        value, gate = x.float().chunk(2, dim = 1)
        activated_gate = F.gelu(gate)
        return value * activated_gate
class FeedForward(nn.Module):
    """Gated feed-forward block built from 1x1x1 ``Conv3d`` layers.

    :param dim: number of input and output channels.
    :param mult: expansion factor; the hidden width is ``int(dim * mult * 2 / 3)``.
    """

    def __init__(self, dim, mult = 4):
        super().__init__()
        inner_dim = int(dim * mult * 2 / 3)

        # project to twice the hidden width so GEGLU can split value / gate
        expand = nn.Conv3d(dim, inner_dim * 2, 1, bias = False)
        self.proj_in = nn.Sequential(expand, GEGLU())

        hidden_norm = ChanLayerNorm(inner_dim)
        contract = nn.Conv3d(inner_dim, dim, 1, bias = False)
        self.proj_out = nn.Sequential(hidden_norm, contract)

    def forward(self, x, enable_time=True):
        """
        :param x: input passed to the ``Conv3d`` projection.
        :param enable_time: when true, ``shift_token`` is applied to the hidden
                            activations before the output projection.
        :return: the output of the second projection.
        """
        hidden = self.proj_in(x)
        hidden = shift_token(hidden) if enable_time else hidden
        return self.proj_out(hidden)
# end of feedforward section
# best relative positional encoding
class ContinuousPositionBias(nn.Module):
    """ from https://arxiv.org/abs/2111.09883

    A small MLP maps the (optionally log-scaled) relative offset between every
    pair of grid positions to one bias value per attention head.

    :param dim: hidden width of the MLP.
    :param heads: number of attention heads (size of the MLP output).
    :param num_dims: number of grid axes, i.e. the size of each offset vector.
    :param layers: number of hidden ``Linear + SiLU`` layers.
    :param log_dist: if True, offsets are mapped to ``sign(d) * log(|d| + 1)``.
    :param cache_rel_pos: if True, the offset table is kept between calls and
        only rebuilt when ``forward`` is called with different dimensions.
    """

    def __init__(
        self,
        *,
        dim,
        heads,
        num_dims = 1,
        layers = 2,
        log_dist = True,
        cache_rel_pos = False
    ):
        super().__init__()
        self.num_dims = num_dims
        self.log_dist = log_dist

        self.net = nn.ModuleList([])
        self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU()))

        for _ in range(layers - 1):
            self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU()))

        self.net.append(nn.Linear(dim, heads))

        self.cache_rel_pos = cache_rel_pos
        self.register_buffer('rel_pos', None, persistent = False)
        # grid sizes the cached ``rel_pos`` table was built for
        self._cached_dims = None

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def _relative_positions(self, dimensions):
        """Build the ``[n x n x num_axes]`` table of pairwise grid offsets."""
        positions = [torch.arange(d, device = self.device) for d in dimensions]
        grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij'))
        # (axes, d1, d2, ...) -> (d1 * d2 * ..., axes)
        grid = grid.reshape(grid.shape[0], -1).transpose(0, 1)
        rel_pos = grid.unsqueeze(1) - grid.unsqueeze(0)

        if self.log_dist:
            rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)

        return rel_pos

    def forward(self, *dimensions):
        """
        :param dimensions: size of each grid axis.
        :return: a ``[heads x n x n]`` bias, ``n`` being the product of ``dimensions``.
        """
        dims = tuple(dimensions)

        # A cached table is only valid for the grid it was built for; reusing it
        # for a different grid would silently return a wrongly shaped bias.
        cache_hit = (
            self.cache_rel_pos
            and self.rel_pos is not None
            and getattr(self, '_cached_dims', None) == dims
        )

        if not cache_hit:
            self.register_buffer('rel_pos', self._relative_positions(dims), persistent = False)
            self._cached_dims = dims

        rel_pos = self.rel_pos.to(self.dtype)

        for layer in self.net:
            rel_pos = layer(rel_pos)

        # (i, j, heads) -> (heads, i, j)
        return rel_pos.permute(2, 0, 1)
# helper classes
class Attention(nn.Module):
    """Multi-head attention with learned position and frame-rate embeddings.

    Queries come from ``x``; keys and values come from ``context`` (or from
    ``x`` itself when no context is given). The output projection is
    zero-initialised and has no bias, so the module outputs zeros at
    initialisation.

    :param dim: feature size of the inputs and outputs.
    :param dim_head: feature size of each attention head.
    :param heads: number of attention heads.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)
        nn.init.zeros_(self.to_out.weight.data) # identity with skip connection

        # both tables hold 30 entries along dim 1; a longer sequence cannot be
        # added to the sliced ``pos_embeds`` and will raise a shape error
        self.pos_embeds = nn.Parameter(torch.randn([1, 30, dim]))
        self.frame_rate_embeds = nn.Parameter(torch.randn([1, 30, dim]))

    def forward(
        self,
        x,
        context = None,
        rel_pos_bias = None,
        framerate = None,
    ):
        """
        :param x: a ``[b x n x dim]`` tensor of queries.
        :param context: optional ``[b x m x dim]`` tensor to attend over;
                        defaults to ``x`` (self-attention).
        :param rel_pos_bias: optional bias added to the attention logits.
        :param framerate: optional 1-based index into the frame-rate table;
                          when ``None`` neither embedding is added.
        :return: a ``[b x n x dim]`` tensor.
        """
        if framerate is not None:
            # The (1, n, dim) and (1, 1, dim) slices broadcast against x, which
            # gives the same values as repeating them without the extra copies.
            x = x + self.pos_embeds[:, :x.shape[1]]
            x = x + self.frame_rate_embeds[:, framerate-1:framerate]

        x = self.norm(x)
        # In self-attention the context is x itself, so reuse the normalized
        # tensor instead of normalizing the same values a second time.
        context = x if context is None else self.norm(context)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# main contribution - pseudo 3d conv
class PseudoConv3d(nn.Module):
    """Factorised 3D convolution: a 2D conv over each frame, then a 1D conv over time.

    The temporal conv is initialised as an identity (Dirac weights, zero bias)
    and is only created when ``kernel_size > 1``.

    :param dim: number of input channels.
    :param dim_out: number of output channels; defaults to ``dim``.
    :param kernel_size: spatial kernel size (padding is ``kernel_size // 2``).
    :param temporal_kernel_size: temporal kernel size; defaults to ``kernel_size``.
    :param kwargs: accepted for call compatibility and ignored.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()
        dim_out = dim if dim_out is None else dim_out
        temporal_kernel_size = kernel_size if temporal_kernel_size is None else temporal_kernel_size

        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
        self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2) if kernel_size > 1 else None

        if self.temporal_conv is not None:
            nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
            nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(
        self,
        x,
        enable_time = True
    ):
        """
        :param x: a ``[b x c x t x h x w]`` video or ``[b x c x h x w]`` image batch.
        :param enable_time: apply the temporal conv (only honoured for video input).
        :return: a tensor with the same layout as ``x`` and ``dim_out`` channels.
        """
        is_video = x.ndim == 5
        enable_time = enable_time and is_video

        if is_video:
            b, c, t, h, w = x.shape
            # fold frames into the batch: (b, c, t, h, w) -> (b * t, c, h, w)
            x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

        x = self.spatial_conv(x)

        if is_video:
            # Read the sizes back from the conv output: an even kernel_size
            # changes h and w, and the pre-conv values would break the reshape.
            _, c, h, w = x.shape
            x = x.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4)

        if not enable_time or self.temporal_conv is None:
            return x

        # fold pixels into the batch: (b, c, t, h, w) -> (b * h * w, c, t)
        x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
        x = self.temporal_conv(x)
        # the temporal length is inferred, since an even temporal kernel changes it
        x = x.reshape(b, h, w, c, -1).permute(0, 3, 4, 1, 2)
        return x
def frame_shift(x, shift_num=8):
    """Delay groups of channels by increasing numbers of frames.

    ``x`` is split along dim 1 into at most ``shift_num`` chunks. Chunk ``i``
    is shifted ``i`` steps forward along dim 2: ``i`` zero frames are
    prepended and the last ``i`` frames are dropped. Chunk 0 is left as is.

    :param x: a tensor whose dim 1 is split and whose dim 2 is shifted.
    :param shift_num: maximum number of chunks (and of distinct shifts).
    :return: a tensor with the same shape as ``x``.
    """
    # ``chunk`` may return fewer than ``shift_num`` pieces (e.g. when dim 1 is
    # smaller than ``shift_num``), so iterate over what it actually produced.
    chunks = x.chunk(shift_num, 1)

    shifted = []
    for offset, chunk in enumerate(chunks):
        if offset > 0:
            lead = torch.zeros_like(chunk[:, :, :offset])
            chunk = torch.cat([lead, chunk[:, :, :-offset]], 2)
        shifted.append(chunk)

    return torch.cat(shifted, 1)
class ResBlockFrameShift(nn.Module):
    """
    A per-frame residual block whose residual branch is frame-shifted in time.

    :param channels: the number of input channels.
    :param dropout: stored on the module; no dropout layer is built here.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels differs from channels, use a
        3x3 convolution instead of a 1x1 convolution in the skip connection.
    :param dims: passed to ``conv_nd`` to choose the convolution dimensionality.
    :param use_checkpoint: stored on the module; not otherwise used here.
    :param up: accepted but not used by this block.
    :param down: accepted but not used by this block.
    """

    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint

        # residual branch; its final conv is wrapped in ``zero_module``
        residual_conv = zero_module(
            conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
        )
        self.out_layers = nn.Sequential(
            normalization(self.channels),
            nn.SiLU(),
            residual_conv,
        )

        # skip branch: identity when the widths match, otherwise a conv
        if self.out_channels == channels:
            skip = nn.Identity()
        elif use_conv:
            skip = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            skip = conv_nd(dims, channels, self.out_channels, 1)
        self.skip_connection = skip

    def forward(self, x):
        """
        Apply the block to a video tensor.

        :param x: a [b x c x t x h x w] Tensor of features.
        :return: a [b x out_channels x t x h x w] Tensor of outputs.
        """
        num_frames = x.shape[2]

        # fold the frames into the batch so the layers run on every frame
        frames = rearrange(x, 'b c t h w -> (b t) c h w')
        residual = self.out_layers(frames)

        # restore the time axis just long enough to shift channels across frames
        residual = rearrange(residual, '(b t) c h w -> b c t h w', t=num_frames)
        residual = frame_shift(residual)
        residual = rearrange(residual, 'b c t h w -> (b t) c h w')

        merged = self.skip_connection(frames) + residual
        return rearrange(merged, '(b t) c h w -> b c t h w', t=num_frames)
class ResBlockVideo(nn.Module):
    """
    A residual block applied independently to every frame of a video tensor.

    :param channels: the number of input channels.
    :param dropout: the rate of dropout used in the output layers.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels differs from channels, use a
        3x3 convolution instead of a 1x1 convolution in the skip connection.
    :param use_scale_shift_norm: stored on the module; not otherwise used here.
    :param dims: passed to ``conv_nd`` to choose the convolution dimensionality.
    :param use_checkpoint: stored on the module; not otherwise used here.
    :param up: accepted but not used by this block.
    :param down: accepted but not used by this block.
    """

    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        first_conv = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            first_conv,
        )

        # the last conv of the residual branch is wrapped in ``zero_module``
        last_conv = zero_module(
            conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            last_conv,
        )

        # skip branch: identity when the widths match, otherwise a conv
        if self.out_channels == channels:
            skip = nn.Identity()
        elif use_conv:
            skip = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            skip = conv_nd(dims, channels, self.out_channels, 1)
        self.skip_connection = skip

    def forward(self, x):
        """
        Apply the block to a video tensor.

        :param x: a [b x c x t x h x w] Tensor of features.
        :return: a [b x out_channels x t x h x w] Tensor of outputs.
        """
        num_frames = x.shape[2]

        # fold the frames into the batch so the layers run on every frame
        frames = rearrange(x, 'b c t h w -> (b t) c h w ')
        residual = self.out_layers(self.in_layers(frames))

        merged = self.skip_connection(frames) + residual
        return rearrange(merged, '(b t) c h w -> b c t h w', t=num_frames)
class Downsample3D(nn.Module):
    """
    A layer with an optional convolution, always built with stride 1.

    :param channels: channels in the inputs.
    :param use_conv: a bool determining if a convolution is applied; when
        False an average pool with kernel size and stride 1 is used instead.
    :param dims: passed to ``conv_nd`` / ``avg_pool_nd`` to choose dimensionality.
    :param stride: accepted but overwritten (see the note in ``__init__``).
    :param out_channels: channels in the outputs; defaults to ``channels``.
    :param padding: padding of the convolution.
    """

    def __init__(self, channels, use_conv, dims=2, stride=None, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims

        # NOTE(review): the ``stride`` argument is unconditionally replaced by
        # 1 here, so despite its name this layer does not reduce resolution.
        # Confirm this is intended before honouring the argument.
        stride = 1

        if not use_conv:
            # pooling cannot change the channel count
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )

    def forward(self, x):
        """
        :param x: a Tensor whose dim 1 has ``channels`` entries.
        :return: the output of the conv / pooling op.
        """
        in_channels = x.shape[1]
        assert in_channels == self.channels
        return self.op(x)
# Video block combining, depending on its flags: attention along the frame
# axis, cross-attention against a reference input ``x_0``, a feed-forward
# layer, a frame-shift residual block and a per-frame ResNet block.
# NOTE(review): indentation is not visible in this copy of the file, so which
# statements are nested under which ``if`` cannot be confirmed from here.
# The comments below describe each statement on its own; verify the nesting
# against the original source before restructuring anything.
class SpatioTemporalAttention(nn.Module):
def __init__(
self,
dim,
*,
dim_head = 64,
heads = 8,
use_resnet = False,
use_frame_shift = True,
use_context_att = False,
use_temp_att = True,
# ``use_context`` is accepted but never read in the visible body
use_context = False,
):
super().__init__()
self.use_resnet = use_resnet
self.use_frame_shift = use_frame_shift
self.use_context_att = use_context_att
self.use_temp_att = use_temp_att
# each optional sub-module is only created when its flag is set
if use_resnet:
self.resblock = ResBlockVideo(dim, dropout=0, dims=2)
if use_frame_shift:
self.frameshiftblock = ResBlockFrameShift(dim, dropout=0, dims=2)
if use_context_att:
# conv from 4 input channels to ``dim`` channels (channels=4, use_conv=True, dims=2)
# presumably 4 is the channel count of the ``x_0`` latent — TODO confirm
self.downsample_x0 = Downsample3D(4, True, 2, out_channels=dim)
self.temporal_attn_x0 = Attention(dim = dim, dim_head = dim_head, heads = heads)
if use_temp_att:
self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
# one-axis relative position bias with an MLP width of ``dim // 2``
self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1)
# NOTE(review): cannot tell from here whether ``ff`` is created unconditionally
# or only under ``use_temp_att`` — confirm
self.ff = FeedForward(dim = dim, mult = 4)
def forward(
self,
x,
x_0 = None,
enable_time = True,
framerate = 4,
is_video = False,
):
x_ndim = x.ndim
# treated as video when the input is 5-D or the caller says so
is_video = x_ndim == 5 or is_video
# the time-dependent path is only taken for video input
enable_time &= is_video
if enable_time:
# size of the last axis of x, compared with that of x_0 below
img_size = x.shape[-1]
if self.use_temp_att:
if x_ndim == 5:
b, c, *_, h, w = x.shape
# fold every spatial location into the batch so attention runs along t
x = rearrange(x, 'b c t h w -> (b h w) t c')
# bias over the ``x.shape[1]`` positions of dim 1
# NOTE(review): ``time_rel_pos_bias``, ``h`` and ``w`` are only assigned in this
# ``use_temp_att`` path, yet they are also read further down — confirm the
# later uses cannot be reached without it
time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1])
if self.use_context_att and x_0 is not None:
x_0_img_size = x_0.shape[-1]
# integer ratio between the last-axis sizes of x_0 and x
kernel_size = x_0_img_size // img_size
# average-pool x_0 by that ratio (stride=None makes the stride equal the kernel size)
x_0 = F.avg_pool2d(x_0, [kernel_size, kernel_size], stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)
# change the channel count to ``dim`` and insert a new axis at position 2
x_0 = self.downsample_x0(x_0).unsqueeze(2)
if x_ndim == 5:
x_0 = rearrange(x_0, 'b c t h w -> (b h w) t c')
# cross-attention against x_0, added back onto x
x = self.temporal_attn_x0(x, context=x_0, rel_pos_bias = time_rel_pos_bias, framerate = framerate) + x
if self.use_temp_att:
# self-attention along dim 1, added back onto x
x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias, framerate = framerate) + x
if x_ndim == 5:
# undo the earlier folding using the saved h and w
x = rearrange(x, '(b h w) t c -> b c t h w', w = w, h = h)
# feed-forward layer, added back onto x
x = self.ff(x, enable_time=enable_time) + x
if self.use_frame_shift:
x = self.frameshiftblock(x)
if self.use_resnet:
x = self.resblock(x)
return x