""" https://github.com/lucidrains/make-a-video-pytorch """ import math import functools from operator import mul import torch from torch import nn, einsum import torch.nn.functional as F from einops import rearrange, repeat, pack, unpack from einops.layers.torch import Rearrange from .modules_conv import avg_pool_nd, zero_module, normalization, conv_nd # helper functions def exists(val): return val is not None def default(val, d): return val if exists(val) else d def mul_reduce(tup): return functools.reduce(mul, tup) def divisible_by(numer, denom): return (numer % denom) == 0 mlist = nn.ModuleList # for time conditioning class SinusoidalPosEmb(nn.Module): def __init__(self, dim, theta = 10000): super().__init__() self.theta = theta self.dim = dim def forward(self, x): dtype, device = x.dtype, x.device assert dtype == torch.float, 'input to sinusoidal pos emb must be a float type' half_dim = self.dim // 2 emb = math.log(self.theta) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb) emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype) class ChanLayerNorm(nn.Module): def __init__(self, dim): super().__init__() self.g = nn.Parameter(torch.ones(dim, 1, 1, 1)) def forward(self, x): eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim = 1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = 1, keepdim = True) x = (x - mean) * var.clamp(min = eps).rsqrt() dtype = self.g.dtype return x.to(dtype) * self.g def shift_token(t): t, t_shift = t.chunk(2, dim = 1) t_shift = F.pad(t_shift, (0, 0, 0, 0, 1, -1), value = 0.) return torch.cat((t, t_shift), dim = 1) class LayerNorm(nn.Module): def __init__(self, dim): super().__init__() self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): eps = 1e-5 if x.dtype == torch.float32 else 1e-3 var = torch.var(x, dim = 1, unbiased = False, keepdim = True) mean = torch.mean(x, dim = 1, keepdim = True) return (x - mean) * var.clamp(min = eps).rsqrt() * self.g # feedforward class GEGLU(nn.Module): def forward(self, x): x = x.float() x, gate = x.chunk(2, dim = 1) return x * F.gelu(gate) class FeedForward(nn.Module): def __init__(self, dim, mult = 4): super().__init__() inner_dim = int(dim * mult * 2 / 3) self.proj_in = nn.Sequential( nn.Conv3d(dim, inner_dim * 2, 1, bias = False), GEGLU() ) self.proj_out = nn.Sequential( ChanLayerNorm(inner_dim), nn.Conv3d(inner_dim, dim, 1, bias = False) ) def forward(self, x, enable_time=True): x = self.proj_in(x) if enable_time: x = shift_token(x) return self.proj_out(x) # feedforwa # best relative positional encoding class ContinuousPositionBias(nn.Module): """ from https://arxiv.org/abs/2111.09883 """ def __init__( self, *, dim, heads, num_dims = 1, layers = 2, log_dist = True, cache_rel_pos = False ): super().__init__() self.num_dims = num_dims self.log_dist = log_dist self.net = nn.ModuleList([]) self.net.append(nn.Sequential(nn.Linear(self.num_dims, dim), nn.SiLU())) for _ in range(layers - 1): self.net.append(nn.Sequential(nn.Linear(dim, dim), nn.SiLU())) self.net.append(nn.Linear(dim, heads)) self.cache_rel_pos = cache_rel_pos self.register_buffer('rel_pos', None, persistent = False) @property def device(self): return next(self.parameters()).device @property def dtype(self): return next(self.parameters()).dtype def forward(self, *dimensions): device = self.device if not exists(self.rel_pos) or not self.cache_rel_pos: positions = [torch.arange(d, device = device) for d in dimensions] grid = torch.stack(torch.meshgrid(*positions, indexing = 'ij')) grid = rearrange(grid, 'c ... -> (...) c') rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c') if self.log_dist: rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1) self.register_buffer('rel_pos', rel_pos, persistent = False) rel_pos = self.rel_pos.to(self.dtype) for layer in self.net: rel_pos = layer(rel_pos) return rearrange(rel_pos, 'i j h -> h i j') # helper classes class Attention(nn.Module): def __init__( self, dim, dim_head = 64, heads = 8 ): super().__init__() self.heads = heads self.scale = dim_head ** -0.5 inner_dim = dim_head * heads self.norm = LayerNorm(dim) self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) self.to_out = nn.Linear(inner_dim, dim, bias = False) nn.init.zeros_(self.to_out.weight.data) # identity with skip connection self.pos_embeds = nn.Parameter(torch.randn([1, 30, dim])) self.frame_rate_embeds = nn.Parameter(torch.randn([1, 30, dim])) def forward( self, x, context = None, rel_pos_bias = None, framerate = None, ): if framerate is not None: x = x + self.pos_embeds[:, :x.shape[1]].repeat(x.shape[0], 1, 1) x = x + self.frame_rate_embeds[:, framerate-1:framerate].repeat(x.shape[0], x.shape[1], 1) if context is None: context = x x = self.norm(x) context = self.norm(context) q, k, v = self.to_q(x), *self.to_kv(context).chunk(2, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v)) q = q * self.scale sim = einsum('b h i d, b h j d -> b h i j', q, k) if exists(rel_pos_bias): sim = sim + rel_pos_bias attn = sim.softmax(dim = -1) out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) # main contribution - pseudo 3d conv class PseudoConv3d(nn.Module): def __init__( self, dim, dim_out = None, kernel_size = 3, *, temporal_kernel_size = None, **kwargs ): super().__init__() dim_out = default(dim_out, dim) temporal_kernel_size = default(temporal_kernel_size, kernel_size) self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2) self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2) if kernel_size > 1 else None if exists(self.temporal_conv): nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity nn.init.zeros_(self.temporal_conv.bias.data) def forward( self, x, enable_time = True ): b, c, *_, h, w = x.shape is_video = x.ndim == 5 enable_time &= is_video if is_video: x = rearrange(x, 'b c t h w -> (b t) c h w') x = self.spatial_conv(x) if is_video: x = rearrange(x, '(b t) c h w -> b c t h w', b = b) if not enable_time or not exists(self.temporal_conv): return x x = rearrange(x, 'b c t h w -> (b h w) c t') x = self.temporal_conv(x) x = rearrange(x, '(b h w) c t -> b c t h w', h = h, w = w) return x def frame_shift(x, shift_num=8): num_frame = x.shape[2] x = list(x.chunk(shift_num, 1)) for i in range(shift_num): if i > 0: shifted = torch.cat([torch.zeros_like(x[i][:, :, :i]), x[i][:, :, :-i]], 2) else: shifted = x[i] x[i] = shifted return torch.cat(x, 1) class ResBlockFrameShift(nn.Module): """ A residual block that can optionally change the number of channels. :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for downsampling. """ def __init__( self, channels, dropout, out_channels=None, use_conv=False, dims=2, use_checkpoint=False, up=False, down=False, ): super().__init__() self.channels = channels self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.out_layers = nn.Sequential( normalization(self.channels), nn.SiLU(), zero_module( conv_nd(dims, self.channels, self.out_channels, 3, padding=1) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( dims, channels, self.out_channels, 3, padding=1 ) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) def forward(self, x): """ Apply the block to a Tensor, conditioned on a timestep embedding. :param x: an [N x C x ...] Tensor of features. :return: an [N x C x ...] Tensor of outputs. """ num_frames = x.shape[2] x = rearrange(x, 'b c t h w -> (b t) c h w') h = self.out_layers(x) h = rearrange(h, '(b t) c h w -> b c t h w', t=num_frames) h = frame_shift(h) h = rearrange(h, 'b c t h w -> (b t) c h w') out = self.skip_connection(x) + h out = rearrange(out, '(b t) c h w -> b c t h w', t=num_frames) return out class ResBlockVideo(nn.Module): """ A residual block that can optionally change the number of channels. :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for downsampling. """ def __init__( self, channels, dropout, out_channels=None, use_conv=False, use_scale_shift_norm=False, dims=2, use_checkpoint=False, up=False, down=False, ): super().__init__() self.channels = channels self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.use_scale_shift_norm = use_scale_shift_norm self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), conv_nd(dims, channels, self.out_channels, 3, padding=1), ) self.out_layers = nn.Sequential( normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), zero_module( conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( dims, channels, self.out_channels, 3, padding=1 ) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) def forward(self, x): """ Apply the block to a Tensor, conditioned on a timestep embedding. :param x: an [N x C x ...] Tensor of features. :return: an [N x C x ...] Tensor of outputs. """ num_frames = x.shape[2] x = rearrange(x, 'b c t h w -> (b t) c h w ') h = x h = self.in_layers(h) h = self.out_layers(h) out = self.skip_connection(x) + h out = rearrange(out, '(b t) c h w -> b c t h w', t=num_frames) return out class Downsample3D(nn.Module): """ A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv, dims=2, stride=None, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims stride = 1 if use_conv: self.op = conv_nd( dims, self.channels, self.out_channels, 3, stride=stride, padding=padding ) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) def forward(self, x): assert x.shape[1] == self.channels return self.op(x) class SpatioTemporalAttention(nn.Module): def __init__( self, dim, *, dim_head = 64, heads = 8, use_resnet = False, use_frame_shift = True, use_context_att = False, use_temp_att = True, use_context = False, ): super().__init__() self.use_resnet = use_resnet self.use_frame_shift = use_frame_shift self.use_context_att = use_context_att self.use_temp_att = use_temp_att if use_resnet: self.resblock = ResBlockVideo(dim, dropout=0, dims=2) if use_frame_shift: self.frameshiftblock = ResBlockFrameShift(dim, dropout=0, dims=2) if use_context_att: self.downsample_x0 = Downsample3D(4, True, 2, out_channels=dim) self.temporal_attn_x0 = Attention(dim = dim, dim_head = dim_head, heads = heads) if use_temp_att: self.temporal_attn = Attention(dim = dim, dim_head = dim_head, heads = heads) self.temporal_rel_pos_bias = ContinuousPositionBias(dim = dim // 2, heads = heads, num_dims = 1) self.ff = FeedForward(dim = dim, mult = 4) def forward( self, x, x_0 = None, enable_time = True, framerate = 4, is_video = False, ): x_ndim = x.ndim is_video = x_ndim == 5 or is_video enable_time &= is_video if enable_time: img_size = x.shape[-1] if self.use_temp_att: if x_ndim == 5: b, c, *_, h, w = x.shape x = rearrange(x, 'b c t h w -> (b h w) t c') time_rel_pos_bias = self.temporal_rel_pos_bias(x.shape[1]) if self.use_context_att and x_0 is not None: x_0_img_size = x_0.shape[-1] kernel_size = x_0_img_size // img_size x_0 = F.avg_pool2d(x_0, [kernel_size, kernel_size], stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None) x_0 = self.downsample_x0(x_0).unsqueeze(2) if x_ndim == 5: x_0 = rearrange(x_0, 'b c t h w -> (b h w) t c') x = self.temporal_attn_x0(x, context=x_0, rel_pos_bias = time_rel_pos_bias, framerate = framerate) + x if self.use_temp_att: x = self.temporal_attn(x, rel_pos_bias = time_rel_pos_bias, framerate = framerate) + x if x_ndim == 5: x = rearrange(x, '(b h w) t c -> b c t h w', w = w, h = h) x = self.ff(x, enable_time=enable_time) + x if self.use_frame_shift: x = self.frameshiftblock(x) if self.use_resnet: x = self.resblock(x) return x