# Textvodeoslashai_v1 / imagen_video.py
# sejamenath2023's picture
# Upload 12 files
# 239ee43
import math
import copy
import operator
import functools
from typing import List
from tqdm.auto import tqdm
from functools import partial, wraps
from contextlib import contextmanager, nullcontext
from collections import namedtuple
from pathlib import Path
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from einops_exts.torch import EinopsToAndFrom
from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
# helper functions
def exists(val):
    """Return True when `val` is anything other than None."""
    is_missing = val is None
    return not is_missing
def identity(t, *args, **kwargs):
    """Return `t` unchanged; any extra positional or keyword arguments are ignored."""
    return t
def first(arr, d = None):
    """Return the first element of `arr`, or `d` when `arr` is empty."""
    return arr[0] if len(arr) > 0 else d
def divisible_by(numer, denom):
    """Return whether `numer` is an exact multiple of `denom`."""
    remainder = numer % denom
    return remainder == 0
def maybe(fn):
    """Wrap single-argument `fn` so that None is passed through instead of being given to `fn`."""
    @wraps(fn)
    def inner(x):
        return fn(x) if x is not None else x
    return inner
def once(fn):
    """Wrap single-argument `fn` so that only the first call reaches it.

    Every later call does nothing and returns None. The "already called" flag
    is set before `fn` runs, so a first call that raises still counts.
    """
    has_run = False

    @wraps(fn)
    def inner(x):
        nonlocal has_run
        if not has_run:
            has_run = True
            return fn(x)
        return None
    return inner

# print that emits only the first message it is ever given
print_once = once(print)
def default(val, d):
    """Return `val` unless it is None; otherwise return `d`, calling it first if it is callable."""
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
def cast_tuple(val, length = None):
    """Coerce `val` to a tuple.

    Lists become tuples and tuples are kept as they are. Any other value is
    repeated `length` times (once when `length` is None). When `length` is
    given, the result is asserted to have exactly that many elements.
    """
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        repeats = 1 if length is None else length
        output = (val,) * repeats

    if length is not None:
        assert len(output) == length

    return output
def cast_uint8_images_to_float(images):
    """Divide uint8 tensors by 255; tensors of any other dtype are returned as is."""
    if images.dtype == torch.uint8:
        return images / 255
    return images
def module_device(module):
    """Return the device of the first parameter yielded by `module.parameters()`."""
    first_param = next(module.parameters())
    return first_param.device
def zero_init_(m):
    """Zero the weight of `m` in place, and its bias too when it has one."""
    nn.init.zeros_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)
def eval_decorator(fn):
    """Decorate a model method so that it runs with the model in eval mode.

    The model's previous `training` flag is recorded, `model.eval()` is called,
    and the flag is restored afterwards. The restore sits in a `finally` so an
    exception inside `fn` cannot leave the model stuck in eval mode.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # runs on both the normal and the exceptional path
            model.train(was_training)
    return inner
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad tuple `t` with `fillvalue` up to `length`; tuples already that long are returned as is."""
    missing = length - len(t)
    if missing > 0:
        padding = (fillvalue,) * missing
        return (*t, *padding)
    return t
# helper classes
class Identity(nn.Module):
    """Module that returns its input unchanged.

    Both the constructor and `forward` accept and ignore arbitrary extra
    arguments.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x
def Sequential(*modules):
    """Build an `nn.Sequential` from `modules`, dropping any entry that is None."""
    present = [module for module in modules if module is not None]
    return nn.Sequential(*present)
# tensor helpers
def log(t, eps: float = 1e-12):
    """Natural log of `t` after clamping it from below at `eps`."""
    clamped = t.clamp(min = eps)
    return clamped.log()
def l2norm(t):
    """Normalise `t` to unit L2 norm along its last dimension."""
    return F.normalize(t, p = 2, dim = -1)
def right_pad_dims_to(x, t):
    """Append singleton dimensions to `t` until it has as many dimensions as `x`.

    `t` is returned untouched when it already has at least as many dimensions.
    """
    extra_dims = x.ndim - t.ndim
    if extra_dims > 0:
        trailing = (1,) * extra_dims
        return t.view(*t.shape, *trailing)
    return t
def masked_mean(t, *, dim, mask = None):
    """Mean of `t` along `dim`, counting only positions where `mask` is True.

    Without a mask this is a plain mean. With one, the 'b n' boolean mask is
    given a trailing singleton axis, masked-out entries are zeroed before
    summing, and the sum is divided by the per-row count of True entries
    (clamped from below so an all-False row does not divide by zero).
    """
    if mask is None:
        return t.mean(dim = dim)

    valid_counts = mask.sum(dim = dim, keepdim = True)
    keep = rearrange(mask, 'b n -> b n 1')
    zeroed = t.masked_fill(~keep, 0.)
    total = zeroed.sum(dim = dim)
    return total / valid_counts.clamp(min = 1e-5)
def resize_video_to(
    video,
    target_image_size,
    target_frames = None,
    clamp_range = None,
    mode = 'nearest'
):
    """Resize a video tensor to `target_frames` x `target_image_size` x `target_image_size`.

    The frame count is read from `video.shape[2]` and the last three axes are
    compared against the target, so the input is presumably (batch, channels,
    frames, height, width) -- confirm against callers.

    Args:
        video: tensor to resize.
        target_image_size: target size for both of the last two axes.
        target_frames: target frame count; defaults to the current one.
        clamp_range: optional (min, max) applied to the interpolated result.
        mode: interpolation mode passed to `F.interpolate`.

    Returns:
        `video` itself (not clamped) when it already has the target shape,
        otherwise the interpolated, optionally clamped, tensor.
    """
    frames = video.shape[2]
    target_frames = default(target_frames, frames)

    target_shape = (target_frames, target_image_size, target_image_size)

    # nothing to do when the last three axes already match
    if tuple(video.shape[-3:]) == target_shape:
        return video

    out = F.interpolate(video, target_shape, mode = mode)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out
def scale_video_time(
    video,
    downsample_scale = 1,
    mode = 'nearest'
):
    """Reduce the frame count of `video` by the integer factor `downsample_scale`.

    A scale of 1 returns `video` untouched. Otherwise the frame count (third
    axis from the end) must divide evenly by the scale, and the video is
    resized in time only, keeping its spatial size (read from the last axis).
    """
    if downsample_scale == 1:
        return video

    frames = video.shape[-3]
    image_size = video.shape[-1]

    assert divisible_by(frames, downsample_scale), f'trying to temporally downsample a conditioning video frames of length {frames} by {downsample_scale}, however it is not neatly divisible'

    return resize_video_to(
        video,
        image_size,
        target_frames = frames // downsample_scale,
        mode = mode
    )
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
    """Boolean tensor of `shape` whose entries are True with probability `prob`.

    `prob` equal to 1 or 0 returns an all-True / all-False mask without
    drawing any random numbers.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)

    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    noise = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return noise < prob
# norms and residuals
class LayerNorm(nn.Module):
    """Layer norm over the last dimension with a learned gain `g` and no bias.

    With `stable = True` the input is first divided by its (detached) maximum
    along the last dimension.
    """
    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = -1, keepdim = True).detach()

        # a larger epsilon is used for anything that is not float32
        eps = 1e-3 if x.dtype != torch.float32 else 1e-5

        mean = torch.mean(x, dim = -1, keepdim = True)
        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)

        centered = x - mean
        return centered * (var + eps).rsqrt() * self.g
class ChanLayerNorm(nn.Module):
    """Layer norm over dimension 1 with a learned gain of shape (1, dim, 1, 1, 1) and no bias.

    With `stable = True` the input is first divided by its (detached) maximum
    along dimension 1.
    """
    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = 1, keepdim = True).detach()

        # a larger epsilon is used for anything that is not float32
        eps = 1e-3 if x.dtype != torch.float32 else 1e-5

        mean = torch.mean(x, dim = 1, keepdim = True)
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)

        centered = x - mean
        return centered * (var + eps).rsqrt() * self.g
class Always():
    """Callable that returns the value given at construction, whatever it is called with."""
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val
class Residual(nn.Module):
    """Wrap `fn` so the module computes `fn(x, **kwargs) + x`."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x
class Parallel(nn.Module):
    """Apply every wrapped module to the same input and sum the results."""
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        return sum(branch(x) for branch in self.fns)
# rearranging
class RearrangeTimeCentric(nn.Module):
    """Run `fn` on the input temporarily laid out as (*, f, c), then restore the layout."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # move the frame and channel axes to the end
        frames_last = rearrange(x, 'b c f ... -> b ... f c')

        # fold every leading axis into one, remembering the shapes for unpacking
        packed, packed_shapes = pack([frames_last], '* f c')

        transformed = self.fn(packed)

        restored, = unpack(transformed, packed_shapes, '* f c')
        return rearrange(restored, 'b ... f c -> b c f ...')
# attention pooling
class PerceiverAttention(nn.Module):
    """Attention from a set of latents to a sequence, Perceiver style.

    Queries are projected from `latents`; keys and values are projected from
    the concatenation of `x` and `latents` along the sequence axis. Queries
    and keys are l2-normalised and multiplied by learned per-dimension scales
    before the dot product, and the similarities are multiplied by the
    constant `scale`.
    """
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # learned per-dimension gains applied after l2-normalising q and k
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        """Return updated latents with the same leading shape as `latents`.

        `mask`, when given, is a boolean 'b j' mask over the positions of `x`;
        the latent positions appended to the keys are always attended to.
        """
        x = self.norm(x)
        latents = self.norm_latents(latents)

        b, h = x.shape[0], self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # qk rmsnorm

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities and masking

        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            # latents were appended after x in kv_input, so extend the mask on the right with True
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
class PerceiverResampler(nn.Module):
    """Resample a variable-length sequence into a fixed number of latent tokens.

    A learned set of `num_latents` latents (optionally preceded by
    `num_latents_mean_pooled` latents projected from the mean of the sequence)
    is refined by `depth` rounds of PerceiverAttention and FeedForward, each
    applied with a residual connection.
    """
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        # learned absolute position embedding, indexed by position in forward
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        """Return the latents after attending to `x`.

        `mask` is forwarded to each attention layer. The sequence length
        (`x.shape[1]`) indexes `pos_emb`, so it must not exceed `max_seq_len`.
        """
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        if exists(self.to_latents_from_mean_pooled_seq):
            # NOTE(review): the mean pooling uses an all-True mask rather than the
            # `mask` argument, so padded positions are included in the mean --
            # confirm this is intended before changing it.
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
# main contribution from make-a-video - pseudo conv3d
# axial space-time convolutions, but made causal to keep in line with the design decisions of imagen-video paper
class Conv3d(nn.Module):
    """Factorised space-time convolution: a 2-D spatial conv, then a causal 1-D temporal conv.

    The spatial conv is applied to every frame independently. The temporal
    conv (created only when `kernel_size > 1`) is applied along the frame axis
    at every spatial location, with all padding on the left so that a frame
    never sees later frames.

    Args:
        dim: input channels.
        dim_out: output channels; defaults to `dim`.
        kernel_size: spatial kernel size.
        temporal_kernel_size: temporal kernel size; defaults to `kernel_size`.
        **kwargs: accepted for call-site compatibility and ignored.
    """
    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)

        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
        self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None
        self.kernel_size = kernel_size

        if exists(self.temporal_conv):
            # NOTE(review): dirac_ places the 1 at the kernel centre while forward pads
            # only on the left, so for kernel sizes above 2 the initial temporal mapping
            # looks like a frame delay rather than an exact identity -- confirm intended.
            nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
            nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(
        self,
        x,
        ignore_time = False
    ):
        """Apply the spatial conv and, for video input, the causal temporal conv.

        Args:
            x: 5-D video 'b c f h w' or 4-D image 'b c h w'.
            ignore_time: skip the temporal conv even for video input.

        Returns:
            Tensor with the same number of dimensions as `x`.
        """
        b, c, *_, h, w = x.shape

        is_video = x.ndim == 5

        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        x = self.spatial_conv(x)

        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

        # images have no frame axis, so the temporal conv can only apply to video
        if not is_video or ignore_time or not exists(self.temporal_conv):
            return x

        x = rearrange(x, 'b c f h w -> (b h w) c f')

        # causal temporal convolution - time is causal in imagen-video
        # pad by the temporal conv's own kernel size so the frame count is preserved
        # even when it differs from the spatial kernel size
        temporal_kernel_size = self.temporal_conv.kernel_size[0]
        x = F.pad(x, (temporal_kernel_size - 1, 0))

        x = self.temporal_conv(x)

        x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

        return x
# attention
class Attention(nn.Module):
    """Multi-query self-attention with a learned null key/value and optional context.

    Queries have `heads` heads, while keys and values are projected to a single
    `dim_head`-sized head shared by all query heads. A learned null key/value
    pair is prepended to the keys/values, and keys/values projected from
    `context` (when given) are prepended in front of that. Queries and keys are
    l2-normalised and multiplied by learned per-dimension scales, and the
    similarities are multiplied by the constant `scale`.
    """
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        causal = False,
        context_dim = None,
        rel_pos_bias = False,
        rel_pos_bias_mlp_depth = 2,
        init_zero = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.causal = causal

        self.rel_pos_bias = DynamicPositionBias(dim = dim, heads = heads, depth = rel_pos_bias_mlp_depth) if rel_pos_bias else None

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        # per-head bias for the null key, used only when an attention bias is present
        self.null_attn_bias = nn.Parameter(torch.randn(heads))

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # single key/value head, shared across all query heads
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

        if init_zero:
            # zero the gain of the final norm so the block initially outputs zeros
            nn.init.zeros_(self.to_out[-1].g)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None
    ):
        """Attend over `x` (shape 'b n d') and return a tensor of the same shape.

        Args:
            x: input sequence.
            context: optional conditioning sequence; requires `context_dim` at construction.
            mask: optional boolean 'b j' key mask.
            attn_bias: optional attention bias; when absent and `rel_pos_bias`
                was enabled, one is computed from the sequence length.
        """
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # qk rmsnorm

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # relative positional encoding (T5 style)

        if not exists(attn_bias) and exists(self.rel_pos_bias):
            attn_bias = self.rel_pos_bias(n, device = device, dtype = q.dtype)

        if exists(attn_bias):
            # NOTE(review): the bias is extended by one column for the null key only;
            # when `context` is also given the keys are longer than n + 1, so this
            # addition looks like it would not line up -- confirm callers never
            # combine context with an attention bias.
            null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
            attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if self.causal:
            i, j = sim.shape[-2:]
            # the diagonal offset of j - i + 1 leaves every extra leading key
            # (null, and context if present) visible to all queries
            causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        if exists(mask):
            # one True column is prepended for the null key
            # NOTE(review): as with the bias above, context keys are not accounted for here.
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# pseudo conv2d that uses conv3d but with kernel size of 1 across frames dimension
def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
    """Build an `nn.Conv3d` that acts per frame, like a 2-D convolution.

    `kernel`, `stride` and `padding` are each cast to a 2-tuple and then given
    a leading frame entry (1 for kernel and stride, 0 for padding). Remaining
    keyword arguments are passed to `nn.Conv3d`.
    """
    kernel, stride, padding = (cast_tuple(value, 2) for value in (kernel, stride, padding))

    # prepend the frame-axis entry so the conv never mixes frames
    kernel = (1, *kernel) if len(kernel) == 2 else kernel
    stride = (1, *stride) if len(stride) == 2 else stride
    padding = (0, *padding) if len(padding) == 2 else padding

    return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
class Pad(nn.Module):
    """Module form of `F.pad` with a fixed padding spec and fill value."""
    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding = padding
        self.value = value

    def forward(self, x):
        padding, fill = self.padding, self.value
        return F.pad(x, padding, value = fill)
# decoder
def Upsample(dim, dim_out = None):
    """Nearest-neighbour 2x spatial upsample followed by a 3x3 per-frame conv.

    The scale factor is given per axis as (frames, height, width) = (1, 2, 2)
    so only the spatial axes are doubled. A scalar factor of 2 would also
    double the frame axis of a 5-D video, which none of the other spatial
    resampling layers in this file (Downsample, PixelShuffleUpsample) do.

    Args:
        dim: input channels.
        dim_out: output channels; defaults to `dim`.
    """
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = (1, 2, 2), mode = 'nearest'),
        Conv2d(dim, dim_out, 3, padding = 1)
    )
class PixelShuffleUpsample(nn.Module):
    """2x spatial upsample: a 1x1 conv to 4x channels, SiLU, then per-frame pixel shuffle."""
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        expand_conv = Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(expand_conv, nn.SiLU())
        self.pixel_shuffle = nn.PixelShuffle(2)

        self.init_conv_(expand_conv)

    def init_conv_(self, conv):
        """Initialise `conv` so each group of 4 consecutive output channels shares one kaiming-uniform filter; the bias is zeroed."""
        out_channels, in_channels, frames, height, width = conv.weight.shape

        base_weight = torch.empty(out_channels // 4, in_channels, frames, height, width)
        nn.init.kaiming_uniform_(base_weight)

        # repeat each filter 4 times along the output-channel axis
        tiled_weight = repeat(base_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        expanded = self.net(x)
        frames = x.shape[2]

        # nn.PixelShuffle works on images, so fold frames into the batch axis
        per_frame = rearrange(expanded, 'b c f h w -> (b f) c h w')
        shuffled = self.pixel_shuffle(per_frame)

        return rearrange(shuffled, '(b f) c h w -> b c f h w', f = frames)
def Downsample(dim, dim_out = None):
    """2x spatial downsample: fold each 2x2 spatial patch into channels, then a 1x1 conv.

    `dim_out` defaults to `dim`.
    """
    dim_out = default(dim_out, dim)

    space_to_depth = Rearrange('b c f (h p1) (w p2) -> b (c p1 p2) f h w', p1 = 2, p2 = 2)
    project = Conv2d(dim * 4, dim_out, 1)

    return nn.Sequential(space_to_depth, project)
# temporal up and downsamples
class TemporalPixelShuffleUpsample(nn.Module):
    """Upsample the frame axis by `stride`: a 1x1 conv to `stride`x channels, SiLU, then a shuffle into time."""
    def __init__(self, dim, dim_out = None, stride = 2):
        super().__init__()
        self.stride = stride
        dim_out = default(dim_out, dim)

        expand_conv = nn.Conv1d(dim, dim_out * stride, 1)

        self.net = nn.Sequential(expand_conv, nn.SiLU())

        # move the `stride` extra channel copies into the frame axis
        self.pixel_shuffle = Rearrange('b (c r) n -> b c (n r)', r = stride)

        self.init_conv_(expand_conv)

    def init_conv_(self, conv):
        """Initialise `conv` so each group of `stride` consecutive output channels shares one kaiming-uniform filter; the bias is zeroed."""
        out_channels, in_channels, kernel = conv.weight.shape

        base_weight = torch.empty(out_channels // self.stride, in_channels, kernel)
        nn.init.kaiming_uniform_(base_weight)

        tiled_weight = repeat(base_weight, 'o ... -> (o r) ...', r = self.stride)

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        b, c, f, h, w = x.shape

        # treat every spatial location as an independent 1-D sequence over frames
        sequences = rearrange(x, 'b c f h w -> (b h w) c f')

        expanded = self.net(sequences)
        upsampled = self.pixel_shuffle(expanded)

        return rearrange(upsampled, '(b h w) c f -> b c f h w', h = h, w = w)
def TemporalDownsample(dim, dim_out = None, stride = 2):
    """Downsample the frame axis by `stride`: fold groups of `stride` frames into channels, then a 1x1 conv.

    `dim_out` defaults to `dim`.
    """
    dim_out = default(dim_out, dim)

    frames_to_depth = Rearrange('b c (f p) h w -> b (c p) f h w', p = stride)
    project = Conv2d(dim * stride, dim_out, 1)

    return nn.Sequential(frames_to_depth, project)
# positional embedding
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D tensor of positions.

    The output has `2 * (dim // 2)` features: sines followed by cosines over
    geometrically spaced frequencies.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2

        # log-spacing between consecutive frequencies
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = x.device) * -log_step)

        # outer product of positions and frequencies
        angles = rearrange(x, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')

        return torch.cat((angles.sin(), angles.cos()), dim = -1)
class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding with learned frequencies.

    For a 1-D input of length b the output is 'b (1 + dim)': the raw input
    value followed by `dim // 2` sines and `dim // 2` cosines.
    """
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi

        # raw value first, then sines, then cosines
        return torch.cat((x, freqs.sin(), freqs.cos()), dim = -1)
class Block(nn.Module):
    """Group norm (optional) -> optional scale/shift -> SiLU -> space-time Conv3d projection."""
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = Conv3d(dim, dim_out, 3, padding = 1)

    def forward(
        self,
        x,
        scale_shift = None,
        ignore_time = False
    ):
        normed = self.groupnorm(x)

        if exists(scale_shift):
            # scale is applied as (scale + 1), so a zero scale leaves the input unchanged
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        activated = self.activation(normed)
        return self.project(activated, ignore_time = ignore_time)
class ResnetBlock(nn.Module):
    """Residual block of two `Block`s with optional time and cross-attention conditioning.

    The time embedding, when configured and supplied, is projected to a
    scale/shift pair applied inside the second block. Cross-attention over
    `cond`, when configured, is applied with a residual connection between
    the two blocks. The output is gated by `gca` (the constant 1 unless
    `use_gca`) and added to a projection of the input.
    """
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None
        if exists(time_cond_dim):
            # produces scale and shift, hence twice the output channels
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None
        if exists(cond_dim):
            attn_klass = LinearCrossAttention if linear_attn else CrossAttention
            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        # a 1x1 conv matches channels on the skip path only when they differ
        self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()

    def forward(
        self,
        x,
        time_emb = None,
        cond = None,
        ignore_time = False
    ):
        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            projected = self.time_mlp(time_emb)
            projected = rearrange(projected, 'b c -> b c 1 1 1')
            scale_shift = projected.chunk(2, dim = 1)

        hidden = self.block1(x, ignore_time = ignore_time)

        if exists(self.cross_attn):
            assert exists(cond)

            # flatten all non-batch, non-channel axes into one token axis for attention
            tokens = rearrange(hidden, 'b c ... -> b ... c')
            tokens, packed_shapes = pack([tokens], 'b * c')

            tokens = self.cross_attn(tokens, context = cond) + tokens

            tokens, = unpack(tokens, packed_shapes, 'b * c')
            hidden = rearrange(tokens, 'b ... c -> b c ...')

        hidden = self.block2(hidden, scale_shift = scale_shift, ignore_time = ignore_time)

        hidden = hidden * self.gca(hidden)

        return hidden + self.res_conv(x)
class CrossAttention(nn.Module):
    """Multi-head cross-attention from `x` to `context` with a learned null key/value.

    Queries and keys are l2-normalised and multiplied by learned per-dimension
    scales, and the similarities are multiplied by the constant `scale`. A
    learned null key/value pair is prepended to the context keys/values.

    Args:
        dim: feature dimension of `x`.
        context_dim: feature dimension of `context`; defaults to `dim`.
        dim_head: dimension per head.
        heads: number of heads.
        norm_context: apply a LayerNorm to the context before projecting.
        scale: constant multiplier applied to the similarities.
    """
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        """Attend from `x` ('b n d') to `context` and return a tensor shaped like `x`.

        `mask`, when given, is a boolean 'b j' mask over the context positions;
        the null key is always attended to.
        """
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # qk rmsnorm

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # one True column is prepended for the null key
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # softmax in float32 for numerical stability, then cast back so the
        # product with `v` does not mix dtypes when the module runs in half precision
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class LinearCrossAttention(CrossAttention):
    """Linear-attention variant of CrossAttention, reusing its parameters.

    Queries are softmaxed over the feature axis and keys over the sequence
    axis; the key/value summary is computed first, so the full query-by-key
    similarity matrix is never formed. Unlike the parent class, queries and
    keys are not l2-normalised here and `scale` multiplies the queries.
    """
    def forward(self, x, context, mask = None):
        """Attend from `x` ('b n d') to `context` and return a tensor shaped like `x`.

        `mask`, when given, is a boolean 'b n' mask over the context positions;
        the null key is always attended to.
        """
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance in prior net

        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            # one True column is prepended for the null key
            mask = F.pad(mask, (1, 0), value = True)
            # heads are folded into the batch axis of k and v, so the mask is
            # repeated per head to match their '(b h) n d' layout
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
class LinearAttention(nn.Module):
    """Linear self-attention over a feature map, with optional extra context keys/values.

    Queries, keys and values each come from dropout, a 1x1 conv and a
    depthwise 3x3 conv. Queries are softmaxed over the feature axis and keys
    over the position axis; the output passes through SiLU, a 1x1 conv and a
    channel layer norm.

    NOTE(review): the rearrange patterns in `forward` ('b (h c) x y') are 4-D,
    while `ChanLayerNorm` holds a 5-D gain and `Conv2d` in this file builds an
    `nn.Conv3d`. It is not clear from this block which input rank is expected
    -- confirm before relying on it for video tensors.
    """
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        # each projection: dropout -> 1x1 conv -> depthwise 3x3 conv (groups = inner_dim)
        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        """Apply linear attention to `fmap`; `context` requires `context_dim` at construction."""
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
            # context keys / values are appended after the feature-map positions
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        # summarise keys and values first, then read the summary out with the queries
        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)
class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        # single-channel logits used to weight every position
        self.to_k = Conv2d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        position_logits = self.to_k(x)

        # flatten every axis after the channel axis into one position axis
        flat_x = rearrange(x, 'b n ... -> b n (...)')
        flat_logits = rearrange(position_logits, 'b n ... -> b n (...)')

        # softmax-weighted average of the features over all positions
        weights = flat_logits.softmax(dim = -1)
        pooled = einsum('b i n, b c n -> b c i', weights, flat_x)

        pooled = rearrange(pooled, '... -> ... 1 1')
        return self.net(pooled)
def FeedForward(dim, mult = 2):
    """Token-wise feedforward: LayerNorm -> Linear -> GELU -> LayerNorm -> Linear, with no biases.

    The hidden width is `int(dim * mult)`.
    """
    hidden_dim = int(dim * mult)

    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False),
    ]

    return nn.Sequential(*layers)
class TimeTokenShift(nn.Module):
    """Shift the second half of the channels one step forward along axis 2 of a 5-D input.

    The vacated first slot is filled with zeros and the last slot is dropped.
    Inputs that are not 5-D are returned untouched.
    """
    def forward(self, x):
        is_video = x.ndim == 5
        if not is_video:
            return x

        kept, shifted = x.chunk(2, dim = 1)

        # F.pad pairs run from the last axis backwards, so the third pair is axis 2:
        # add one slot at the front and remove one from the end
        frame_shift = (0, 0, 0, 0, 1, -1)
        shifted = F.pad(shifted, frame_shift, value = 0.)

        return torch.cat((kept, shifted), dim = 1)
def ChanFeedForward(dim, mult = 2, time_token_shift = True): # in paper, it seems for self attention layers they did feedforwards with twice channel width
    """Channel-wise feedforward built from 1x1 convs, with an optional token shift after the GELU.

    The hidden width is `int(dim * mult)`. When `time_token_shift` is False
    the shift layer is left out of the returned sequence altogether.
    """
    hidden_dim = int(dim * mult)

    shift = TimeTokenShift() if time_token_shift else None

    return Sequential(
        ChanLayerNorm(dim),
        Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        shift,
        ChanLayerNorm(hidden_dim),
        Conv2d(hidden_dim, dim, 1, bias = False)
    )
class TransformerBlock(nn.Module):
    """`depth` layers of full attention over all flattened positions, each followed by a channel feedforward.

    Both sub-layers are applied with residual connections.
    """
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            # channels last, then flatten every non-batch axis into one token axis
            tokens = rearrange(x, 'b c ... -> b ... c')
            tokens, packed_shapes = pack([tokens], 'b * c')

            tokens = attn(tokens, context = context) + tokens

            # restore the original channels-first layout for the feedforward
            tokens, = unpack(tokens, packed_shapes, 'b * c')
            x = rearrange(tokens, 'b ... c -> b c ...')

            x = ff(x) + x
        return x
class LinearAttentionTransformerBlock(nn.Module):
    """`depth` layers of LinearAttention, each followed by a channel feedforward, both with residuals.

    Extra keyword arguments are accepted and ignored.
    """
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
            ])
            for _ in range(depth)
        ])

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            attended = attn(x, context = context) + x
            x = ff(attended) + attended
        return x
class CrossEmbedLayer(nn.Module):
    """Several convs with different kernel sizes applied in parallel and concatenated on channels.

    Kernel sizes are sorted ascending and each scale receives half the
    channels of the one before it, the largest kernel taking whatever remains
    so that the total is `dim_out`. Every kernel size must have the same
    parity as `stride`.
    """
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all((size % 2) == (stride % 2) for size in kernel_sizes)
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # calculate the dimension at each scale
        halved_dims = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*halved_dims, dim_out - sum(halved_dims)]

        self.convs = nn.ModuleList([
            Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)
            for kernel, dim_scale in zip(kernel_sizes, dim_scales)
        ])

    def forward(self, x):
        fmaps = tuple(conv(x) for conv in self.convs)
        return torch.cat(fmaps, dim = 1)
class UpsampleCombiner(nn.Module):
    """Optionally concatenate resized, projected feature maps onto `x` along channels.

    When disabled (the default) the module is a pass-through and `dim_out`
    equals `dim`. When enabled, one `Block` is created per entry of `dim_ins`
    and `dim_out` is `dim` plus the sum of `dim_outs`.
    """
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if self.enabled:
            self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
            # summing an empty tuple gives 0, so no feature maps means dim_out == dim
            self.dim_out = dim + sum(dim_outs)
        else:
            self.dim_out = dim

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        # bring every feature map to the spatial size of x before projecting it
        resized = [resize_video_to(fmap, target_size) for fmap in fmaps]
        projected = [conv(fmap) for fmap, conv in zip(resized, self.fmap_convs)]
        return torch.cat((x, *projected), dim = 1)
class DynamicPositionBias(nn.Module):
    """Relative position bias produced by an MLP applied to the scalar relative distance.

    The MLP has one input layer (1 -> dim), `depth - 1` hidden layers, and a
    final linear layer producing one value per head.
    """
    def __init__(
        self,
        dim,
        *,
        heads,
        depth
    ):
        super().__init__()

        def hidden_layer(dim_in):
            return nn.Sequential(
                nn.Linear(dim_in, dim),
                LayerNorm(dim),
                nn.SiLU()
            )

        layers = [hidden_layer(1)]
        layers.extend(hidden_layer(dim) for _ in range(max(depth - 1, 0)))
        layers.append(nn.Linear(dim, heads))

        self.mlp = nn.ModuleList(layers)

    def forward(self, n, device, dtype):
        """Return a bias of shape 'h i j' for a sequence of length `n`."""
        rows = torch.arange(n, device = device)
        cols = torch.arange(n, device = device)

        # relative distance i - j, offset by n - 1 so it can index into `pos`
        indices = rearrange(rows, 'i -> i 1') - rearrange(cols, 'j -> 1 j')
        indices = indices + (n - 1)

        # every possible relative distance, from -(n - 1) to n - 1
        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')

        for layer in self.mlp:
            pos = layer(pos)

        bias = pos[indices]
        return rearrange(bias, 'i j h -> h i j')
class Unet3D(nn.Module):
def __init__(
    self,
    *,
    dim,
    text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
    num_resnet_blocks = 1,
    cond_dim = None,
    num_image_tokens = 4,
    num_time_tokens = 2,
    learned_sinu_pos_emb_dim = 16,
    out_dim = None,
    dim_mults = (1, 2, 4, 8),
    temporal_strides = 1,
    cond_images_channels = 0,
    channels = 3,
    channels_out = None,
    attn_dim_head = 64,
    attn_heads = 8,
    ff_mult = 2.,
    ff_time_token_shift = True, # this would do a token shift along time axis, at the hidden layer within feedforwards - from successful use in RWKV (Peng et al), and other token shift video transformer works
    lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
    layer_attns = False,
    layer_attns_depth = 1,
    layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
    attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
    time_rel_pos_bias_depth = 2,
    time_causal_attn = True,
    layer_cross_attns = True,
    use_linear_attn = False,
    use_linear_cross_attn = False,
    cond_on_text = True,
    max_text_len = 256,
    init_dim = None,
    resnet_groups = 8,
    init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
    init_cross_embed = True,
    init_cross_embed_kernel_sizes = (3, 7, 15),
    cross_embed_downsample = False,
    cross_embed_downsample_kernel_sizes = (2, 4),
    attn_pool_text = True,
    attn_pool_num_latents = 32,
    dropout = 0.,
    memory_efficient = False,
    init_conv_to_final_conv_residual = False,
    use_global_context_attn = True,
    scale_skip_connection = True,
    final_resnet_block = True,
    final_conv_kernel_size = 3,
    self_cond = False,
    combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
    pixel_shuffle_upsample = True, # may address checkboard artifacts
    resize_mode = 'nearest'
):
    """
    Build all sub-modules of the video u-net.

    In order, this constructs: the initial convolution, the time (and optional
    low-res noise-time) conditioning projections, the optional text conditioning
    projections and null embeddings, the downsampling stack, the middle blocks,
    the upsampling stack, the optional upsample feature-map combiner, and the
    zero-initialised final convolution.

    All keyword arguments are captured in `self._locals` so the unet can be
    re-instantiated with modified settings (see `cast_model_parameters`) or
    serialised together with its weights (see `to_config_and_state_dict`).

    Per-layer settings (`num_resnet_blocks`, `resnet_groups`, `layer_attns`,
    `layer_attns_depth`, `layer_cross_attns`, `temporal_strides`) may be given
    as a single value or as a sequence with one entry per resolution, i.e. one
    per entry of `dim_mults`.

    NOTE(review): `num_image_tokens`, `out_dim`, `layer_attns_add_text_cond`,
    `time_rel_pos_bias_depth` and `dropout` are accepted but not referenced
    anywhere in this constructor other than through `self._locals` - confirm
    whether that is intended.
    """
    super().__init__()
    # guide researchers
    assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'
    if dim < 128:
        print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')
    # save locals to take care of some hyperparameters for cascading DDPM
    # (taken before any other local is defined, so this holds exactly the constructor arguments)
    self._locals = locals()
    self._locals.pop('self', None)
    self._locals.pop('__class__', None)
    self.self_cond = self_cond
    # determine dimensions
    self.channels = channels
    self.channels_out = default(channels_out, channels)
    # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
    # (2) in self conditioning, one appends the predict x0 (x_start)
    init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
    init_dim = default(init_dim, dim)
    # optional image conditioning
    self.has_cond_image = cond_images_channels > 0
    self.cond_images_channels = cond_images_channels
    init_channels += cond_images_channels
    # initial convolution
    self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
    # channel widths per resolution, and the (in, out) pair for each down / up stage
    dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
    in_out = list(zip(dims[:-1], dims[1:]))
    # time conditioning
    cond_dim = default(cond_dim, dim)
    time_cond_dim = dim * 4 * (2 if lowres_cond else 1)
    # embedding time for log(snr) noise from continuous version
    sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
    # presumably the embedding outputs the input value plus `learned_sinu_pos_emb_dim` features, hence the + 1 - TODO confirm against LearnedSinusoidalPosEmb
    sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1
    self.to_time_hiddens = nn.Sequential(
        sinu_pos_emb,
        nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
        nn.SiLU()
    )
    self.to_time_cond = nn.Sequential(
        nn.Linear(time_cond_dim, time_cond_dim)
    )
    # project to time tokens as well as time hiddens
    self.to_time_tokens = nn.Sequential(
        nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
        Rearrange('b (r d) -> b r d', r = num_time_tokens)
    )
    # low res aug noise conditioning
    # (mirrors the three time projections above, applied to the low-res noise times)
    self.lowres_cond = lowres_cond
    if lowres_cond:
        self.to_lowres_time_hiddens = nn.Sequential(
            LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
            nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
            nn.SiLU()
        )
        self.to_lowres_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )
        self.to_lowres_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )
    # normalizations
    self.norm_cond = nn.LayerNorm(cond_dim)
    # text encoding conditioning (optional)
    self.text_to_cond = None
    if cond_on_text:
        assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
        self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
    # finer control over whether to condition on text encodings
    self.cond_on_text = cond_on_text
    # attention pooling
    self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents) if attn_pool_text else None
    # for classifier free guidance
    # (learned placeholders that `forward` substitutes where text conditioning is dropped)
    self.max_text_len = max_text_len
    self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
    self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))
    # for non-attention based text conditioning at all points in the network where time is also conditioned
    self.to_text_non_attn_cond = None
    if cond_on_text:
        self.to_text_non_attn_cond = nn.Sequential(
            nn.LayerNorm(cond_dim),
            nn.Linear(cond_dim, time_cond_dim),
            nn.SiLU(),
            nn.Linear(time_cond_dim, time_cond_dim)
        )
    # attention related params
    attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
    num_layers = len(in_out)
    # temporal attention - attention across video frames
    # presumably the padding tuple is in F.pad order, so the last pair pads the time axis
    # (2 before / 0 after when causal, 1 / 1 otherwise) - TODO confirm against Pad
    temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1)
    # factories, so that each call site below gets its own freshly initialised module
    temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim)))
    temporal_attn = lambda dim: RearrangeTimeCentric(Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn, 'init_zero': True, 'rel_pos_bias': True})))
    # resnet block klass
    # broadcast every per-layer setting to one entry per resolution
    num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
    resnet_groups = cast_tuple(resnet_groups, num_layers)
    resnet_klass = partial(ResnetBlock, **attn_kwargs)
    layer_attns = cast_tuple(layer_attns, num_layers)
    layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
    layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
    assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
    # temporal downsample config
    temporal_strides = cast_tuple(temporal_strides, num_layers)
    # product of all temporal strides; `forward` requires the frame count to be divisible by it
    self.total_temporal_divisor = functools.reduce(operator.mul, temporal_strides, 1)
    # downsample klass
    downsample_klass = Downsample
    if cross_embed_downsample:
        downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
    # initial resnet block (for memory efficient unet)
    self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
    self.init_temporal_peg = temporal_peg(init_dim)
    self.init_temporal_attn = temporal_attn(init_dim)
    # scale for resnet skip connections
    self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
    # layers
    self.downs = nn.ModuleList([])
    self.ups = nn.ModuleList([])
    num_resolutions = len(in_out)
    layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, temporal_strides]
    # one-shot reversed iterators, consumed by the single zip in the upsampling loop below
    reversed_layer_params = list(map(reversed, layer_params))
    # downsampling layers
    skip_connect_dims = [] # keep track of skip connection dimensions
    for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(in_out, *layer_params)):
        is_last = ind >= (num_resolutions - 1)
        # linear cross attention is only used as a fallback where full cross attention is off for the layer
        layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
        layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
        transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)
        current_dim = dim_in
        # whether to pre-downsample, from memory efficient unet
        pre_downsample = None
        if memory_efficient:
            pre_downsample = downsample_klass(dim_in, dim_out)
            current_dim = dim_out
        skip_connect_dims.append(current_dim)
        # whether to do post-downsample, for non-memory efficient unet
        # (the last stage uses two parallel convs instead of a downsample)
        post_downsample = None
        if not memory_efficient:
            post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(Conv2d(dim_in, dim_out, 3, padding = 1), Conv2d(dim_in, dim_out, 1))
        # the order of this list must match the tuple unpacking over `self.downs` in `forward`
        self.downs.append(nn.ModuleList([
            pre_downsample,
            resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
            nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
            transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
            temporal_peg(current_dim),
            temporal_attn(current_dim),
            TemporalDownsample(current_dim, stride = temporal_stride) if temporal_stride > 1 else None,
            post_downsample
        ]))
    # middle layers
    mid_dim = dims[-1]
    self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
    self.mid_attn = EinopsToAndFrom('b c f h w', 'b (f h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
    self.mid_temporal_peg = temporal_peg(mid_dim)
    self.mid_temporal_attn = temporal_attn(mid_dim)
    self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
    # upsample klass
    upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
    # upsampling layers
    upsample_fmap_dims = []
    for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
        is_last = ind == (len(in_out) - 1)
        layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
        layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
        transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)
        # width of the skip tensors concatenated onto this stage's input in `forward`
        skip_connect_dim = skip_connect_dims.pop()
        upsample_fmap_dims.append(dim_out)
        # the order of this list must match the tuple unpacking over `self.ups` in `forward`
        self.ups.append(nn.ModuleList([
            resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
            nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
            transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
            temporal_peg(dim_out),
            temporal_attn(dim_out),
            TemporalPixelShuffleUpsample(dim_out, stride = temporal_stride) if temporal_stride > 1 else None,
            upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
        ]))
    # whether to combine feature maps from all upsample blocks before final resnet block out
    self.upsample_combiner = UpsampleCombiner(
        dim = dim,
        enabled = combine_upsample_fmaps,
        dim_ins = upsample_fmap_dims,
        dim_outs = dim
    )
    # whether to do a final residual from initial conv to the final resnet block out
    self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
    # NOTE(review): the residual concatenated in `forward` is the output of `init_conv`, which has
    # `init_dim` channels, while `dim` is added here - these agree only when `init_dim` is left
    # at its default (= dim). Confirm whether a custom `init_dim` is meant to be supported here.
    final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)
    # final optional resnet block and convolution out
    self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
    final_conv_dim_in = dim if final_resnet_block else final_conv_dim
    # the low resolution conditioning image is concatenated once more right before the final conv
    final_conv_dim_in += (channels if lowres_cond else 0)
    self.final_conv = Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)
    # zero the final conv weight and bias
    zero_init_(self.final_conv)
    # resize mode
    self.resize_mode = resize_mode
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def cast_model_parameters(
    self,
    *,
    lowres_cond,
    text_embed_dim,
    channels,
    channels_out,
    cond_on_text
):
    """
    Return a unet configured with the requested settings.

    If the requested settings already match this instance, `self` is returned.
    Otherwise a new instance of the same class is built from the saved
    constructor arguments with the requested values overriding them.
    """
    already_matching = (
        lowres_cond == self.lowres_cond
        and channels == self.channels
        and cond_on_text == self.cond_on_text
        and text_embed_dim == self._locals['text_embed_dim']
        and channels_out == self.channels_out
    )

    if not already_matching:
        overrides = {
            'lowres_cond': lowres_cond,
            'text_embed_dim': text_embed_dim,
            'channels': channels,
            'channels_out': channels_out,
            'cond_on_text': cond_on_text
        }
        # later keys win, so the overrides replace the saved constructor arguments
        rebuilt_kwargs = {**self._locals, **overrides}
        return self.__class__(**rebuilt_kwargs)

    return self
# methods for returning the full unet config as well as its parameter state
def to_config_and_state_dict(self):
    """Return a `(config, state_dict)` pair: the saved constructor arguments and the current parameters."""
    config = self._locals
    weights = self.state_dict()
    return config, weights
# class method for rehydrating the unet from its config and state dict
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
    """Instantiate the class from `config` (used as keyword arguments) and load `state_dict` into it."""
    model = klass(**config)
    model.load_state_dict(state_dict)
    return model
# methods for persisting unet to disk
def persist_to_file(self, path):
    """Save the unet config and state dict to `path` with `torch.save`, creating parent directories as needed."""
    target = Path(path)
    target.parents[0].mkdir(parents = True, exist_ok = True)

    config, state_dict = self.to_config_and_state_dict()

    # same layout that `hydrate_from_file` expects
    package = {'config': config, 'state_dict': state_dict}
    torch.save(package, str(target))
# class method for rehydrating the unet from file saved with `persist_to_file`
@classmethod
def hydrate_from_file(klass, path):
    """
    Rebuild a unet from a file written by `persist_to_file`.

    The file must contain a dict with a 'config' entry (constructor keyword
    arguments) and a 'state_dict' entry (parameters). The unet is built through
    `klass`, so subclasses rehydrate as themselves.

    Raises AssertionError if the file does not exist or lacks the expected keys.
    """
    path = Path(path)
    assert path.exists(), f'{str(path)} does not exist'

    # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources
    pkg = torch.load(str(path))
    assert 'config' in pkg and 'state_dict' in pkg, 'file must contain both `config` and `state_dict`'

    config, state_dict = pkg['config'], pkg['state_dict']

    # dispatch through the class this was called on, instead of the undefined name `Unet`
    return klass.from_config_and_state_dict(config, state_dict)
# forward with classifier free guidance
def forward_with_cond_scale(
    self,
    *args,
    cond_scale = 1.,
    **kwargs
):
    """
    Forward pass with classifier free guidance.

    Runs the conditional forward pass, and when `cond_scale` is not 1 also runs
    a pass with `cond_drop_prob = 1.`, returning
    `null + (conditional - null) * cond_scale`.
    """
    cond_out = self.forward(*args, **kwargs)

    if cond_scale != 1:
        # second pass with the conditioning fully dropped
        null_out = self.forward(*args, cond_drop_prob = 1., **kwargs)
        guidance_direction = cond_out - null_out
        return null_out + guidance_direction * cond_scale

    return cond_out
def forward(
    self,
    x,
    time,
    *,
    lowres_cond_img = None,
    lowres_noise_times = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    cond_video_frames = None,
    post_cond_video_frames = None,
    self_cond = None,
    cond_drop_prob = 0.,
    ignore_time = False
):
    """
    Run the unet on a video tensor.

    Args:
        x: 5-dimensional input (batch, channels, time, height, width).
        time: noise times, fed to the time conditioning projections.
        lowres_cond_img: low resolution conditioning video, required when the unet
            was built with `lowres_cond = True`. Concatenated to `x` along channels
            at the input and again right before the final convolution.
        lowres_noise_times: noise times for the low resolution conditioning,
            required when the unet was built with `lowres_cond = True`.
        text_embeds: text encodings, used only when the unet conditions on text.
        text_mask: optional boolean mask over the text tokens.
        cond_images: 4-dimensional conditioning images, repeated across all frames.
            Must be given if and only if `cond_images_channels > 0`.
        cond_video_frames: frames prepended to `x` along the time axis as a prompt.
        post_cond_video_frames: frames appended to `x` along the time axis as a prompt.
        self_cond: optional self conditioning tensor (zeros are used if omitted
            and the unet was built with `self_cond = True`).
        cond_drop_prob: probability of dropping the text conditioning per sample.
        ignore_time: if True, all temporal layers are skipped.

    Returns:
        The output of the final convolution, with any prepended / appended
        conditioning frames removed from the time axis.
    """
    assert x.ndim == 5, 'input to 3d unet must have 5 dimensions (batch, channels, time, height, width)'

    batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype

    assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames {frames} must be divisible by {self.total_temporal_divisor}'

    # add self conditioning if needed

    if self.self_cond:
        self_cond = default(self_cond, lambda: torch.zeros_like(x))
        x = torch.cat((x, self_cond), dim = 1)

    # add low resolution conditioning, if present

    assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
    assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)

        # extend the low res conditioning along time so it lines up with the
        # prompt frames added to x below, and double the prompt frames along
        # channels so they match x after the channel concat above

        if exists(cond_video_frames):
            lowres_cond_img = torch.cat((cond_video_frames, lowres_cond_img), dim = 2)
            cond_video_frames = torch.cat((cond_video_frames, cond_video_frames), dim = 1)

        if exists(post_cond_video_frames):
            lowres_cond_img = torch.cat((lowres_cond_img, post_cond_video_frames), dim = 2)
            post_cond_video_frames = torch.cat((post_cond_video_frames, post_cond_video_frames), dim = 1)

    # conditioning on preceding video frames as a prompt

    num_preceding_frames = 0
    if exists(cond_video_frames):
        cond_video_frames_len = cond_video_frames.shape[2]

        assert divisible_by(cond_video_frames_len, self.total_temporal_divisor)

        cond_video_frames = resize_video_to(cond_video_frames, x.shape[-1])
        x = torch.cat((cond_video_frames, x), dim = 2)

        num_preceding_frames = cond_video_frames_len

    # conditioning on succeeding video frames as a prompt

    num_succeeding_frames = 0
    if exists(post_cond_video_frames):
        cond_video_frames_len = post_cond_video_frames.shape[2]

        assert divisible_by(cond_video_frames_len, self.total_temporal_divisor)

        post_cond_video_frames = resize_video_to(post_cond_video_frames, x.shape[-1])

        # succeeding frames go AFTER x on the time axis, matching both the low res
        # conditioning layout above and the trailing slice applied to the output below
        x = torch.cat((x, post_cond_video_frames), dim = 2)

        num_succeeding_frames = cond_video_frames_len

    # condition on input image

    assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

    if exists(cond_images):
        assert cond_images.ndim == 4, 'conditioning images must have 4 dimensions only, if you want to condition on frames of video, use `cond_video_frames` instead'
        assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'

        cond_images = repeat(cond_images, 'b c h w -> b c f h w', f = x.shape[2])
        cond_images = resize_video_to(cond_images, x.shape[-1], mode = self.resize_mode)
        x = torch.cat((cond_images, x), dim = 1)

    # ignoring time in pseudo 3d resnet blocks

    conv_kwargs = dict(
        ignore_time = ignore_time
    )

    # initial convolution

    x = self.init_conv(x)

    if not ignore_time:
        x = self.init_temporal_peg(x)
        x = self.init_temporal_attn(x)

    # init conv residual

    if self.init_conv_to_final_conv_residual:
        init_conv_residual = x.clone()

    # time conditioning

    time_hiddens = self.to_time_hiddens(time)

    # derive time tokens

    time_tokens = self.to_time_tokens(time_hiddens)
    t = self.to_time_cond(time_hiddens)

    # add lowres time conditioning to time hiddens
    # and add lowres time tokens along sequence dimension for attention

    if self.lowres_cond:
        lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
        lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
        lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

        t = t + lowres_t
        time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)

    # text conditioning

    text_tokens = None

    if exists(text_embeds) and self.cond_on_text:

        # conditional dropout

        text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

        text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
        text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

        # calculate text embeds

        text_tokens = self.text_to_cond(text_embeds)

        # truncate, then pad, so the token axis is always exactly max_text_len long

        text_tokens = text_tokens[:, :self.max_text_len]

        if exists(text_mask):
            text_mask = text_mask[:, :self.max_text_len]

        text_tokens_len = text_tokens.shape[1]
        remainder = self.max_text_len - text_tokens_len

        if remainder > 0:
            text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

        if exists(text_mask):
            if remainder > 0:
                text_mask = F.pad(text_mask, (0, remainder), value = False)

            text_mask = rearrange(text_mask, 'b n -> b n 1')
            text_keep_mask_embed = text_mask & text_keep_mask_embed

        null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working

        # masked-out or dropped positions fall back to the learned null embedding

        text_tokens = torch.where(
            text_keep_mask_embed,
            text_tokens,
            null_text_embed
        )

        if exists(self.attn_pool):
            text_tokens = self.attn_pool(text_tokens)

        # extra non-attention conditioning by projecting and then summing text embeddings to time
        # termed as text hiddens

        mean_pooled_text_tokens = text_tokens.mean(dim = -2)

        text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

        null_text_hidden = self.null_text_hidden.to(t.dtype)

        text_hiddens = torch.where(
            text_keep_mask_hidden,
            text_hiddens,
            null_text_hidden
        )

        t = t + text_hiddens

    # main conditioning tokens (c)

    c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

    # normalize conditioning tokens

    c = self.norm_cond(c)

    # initial resnet block (for memory efficient unet)

    if exists(self.init_resnet_block):
        x = self.init_resnet_block(x, t, **conv_kwargs)

    # go through the layers of the unet, down and up

    hiddens = []

    for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_downsample, post_downsample in self.downs:
        if exists(pre_downsample):
            x = pre_downsample(x)

        x = init_block(x, t, c, **conv_kwargs)

        # one skip connection is stored per resnet block ...

        for resnet_block in resnet_blocks:
            x = resnet_block(x, t, **conv_kwargs)
            hiddens.append(x)

        x = attn_block(x, c)

        if not ignore_time:
            x = temporal_peg(x)
            x = temporal_attn(x)

        # ... plus one after the attention layers, consumed by the matching up stage

        hiddens.append(x)

        if exists(temporal_downsample) and not ignore_time:
            x = temporal_downsample(x)

        if exists(post_downsample):
            x = post_downsample(x)

    x = self.mid_block1(x, t, c, **conv_kwargs)

    if exists(self.mid_attn):
        x = self.mid_attn(x)

    if not ignore_time:
        x = self.mid_temporal_peg(x)
        x = self.mid_temporal_attn(x)

    x = self.mid_block2(x, t, c, **conv_kwargs)

    # skip connections are consumed in reverse order of how they were stored

    add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)

    up_hiddens = []

    for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_upsample, upsample in self.ups:
        if exists(temporal_upsample) and not ignore_time:
            x = temporal_upsample(x)

        x = add_skip_connection(x)
        x = init_block(x, t, c, **conv_kwargs)

        for resnet_block in resnet_blocks:
            x = add_skip_connection(x)
            x = resnet_block(x, t, **conv_kwargs)

        x = attn_block(x, c)

        if not ignore_time:
            x = temporal_peg(x)
            x = temporal_attn(x)

        up_hiddens.append(x.contiguous())

        x = upsample(x)

    # whether to combine all feature maps from upsample blocks

    x = self.upsample_combiner(x, up_hiddens)

    # final top-most residual if needed

    if self.init_conv_to_final_conv_residual:
        x = torch.cat((x, init_conv_residual), dim = 1)

    if exists(self.final_res_block):
        x = self.final_res_block(x, t, **conv_kwargs)

    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)

    out = self.final_conv(x)

    # strip the prompt frames again: preceding ones from the front, succeeding ones from the back

    if num_preceding_frames > 0:
        out = out[:, :, num_preceding_frames:]

    if num_succeeding_frames > 0:
        out = out[:, :, :-num_succeeding_frames]

    return out