import math
import copy
import operator
import functools
from typing import List
from tqdm.auto import tqdm
from functools import partial, wraps
from contextlib import contextmanager, nullcontext
from collections import namedtuple
from pathlib import Path

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from einops_exts.torch import EinopsToAndFrom

from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

# helper functions

def exists(val):
    return val is not None

def identity(t, *args, **kwargs):
    return t

def first(arr, d = None):
    if len(arr) == 0:
        return d
    return arr[0]

def divisible_by(numer, denom):
    return (numer % denom) == 0

def maybe(fn):
    @wraps(fn)
    def inner(x):
        if not exists(x):
            return x
        return fn(x)
    return inner

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None):
    if isinstance(val, list):
        val = tuple(val)

    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))

    if exists(length):
        assert len(output) == length

    return output

def cast_uint8_images_to_float(images):
    if not images.dtype == torch.uint8:
        return images
    return images / 255

def module_device(module):
    return next(module.parameters()).device

def zero_init_(m):
    nn.init.zeros_(m.weight)
    if exists(m.bias):
        nn.init.zeros_(m.bias)

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

# helper classes

class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

def Sequential(*modules):
    return nn.Sequential(*filter(exists, modules))

# tensor helpers

def log(t, eps: float = 1e-12):
    return torch.log(t.clamp(min = eps))

def l2norm(t):
    return F.normalize(t, dim = -1)

def right_pad_dims_to(x, t):
    padding_dims = x.ndim - t.ndim
    if padding_dims <= 0:
        return t
    return t.view(*t.shape, *((1,) * padding_dims))

def masked_mean(t, *, dim, mask = None):
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)

def resize_video_to(
    video,
    target_image_size,
    target_frames = None,
    clamp_range = None,
    mode = 'nearest'
):
    orig_video_size = video.shape[-1]

    frames = video.shape[2]
    target_frames = default(target_frames, frames)

    target_shape = (target_frames, target_image_size, target_image_size)

    if tuple(video.shape[-3:]) == target_shape:
        return video

    out = F.interpolate(video, target_shape, mode = mode)

    if exists(clamp_range):
        out = out.clamp(*clamp_range)

    return out

def scale_video_time(
    video,
    downsample_scale = 1,
    mode = 'nearest'
):
    if downsample_scale == 1:
        return video

    image_size, frames = video.shape[-1], video.shape[-3]
    assert divisible_by(frames, downsample_scale), f'trying to temporally downsample a conditioning video of {frames} frames by a factor of {downsample_scale}, but it is not evenly divisible'

    target_frames = frames // downsample_scale

    resized_video = resize_video_to(
        video,
        image_size,
        target_frames = target_frames,
        mode = mode
    )

    return resized_video
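
# illustrative sketch (not part of the original module): temporally downsample a
# conditioning video by 2x while keeping its spatial size
#
#   video = torch.randn(1, 3, 16, 64, 64)       # (batch, channels, frames, height, width)
#   out = scale_video_time(video, downsample_scale = 2)
#   assert out.shape == (1, 3, 8, 64, 64)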

# classifier free guidance functions

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
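
# illustrative note (not part of the original module): used for classifier free guidance
# dropout of the text conditioning - each batch element keeps its conditioning with
# probability `prob`
#
#   keep_mask = prob_mask_like((4,), 1 - 0.1, device = torch.device('cpu'))
#   # e.g. tensor([True, True, False, True]) - False rows fall back to the null embedding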

# norms and residuals

class LayerNorm(nn.Module):
    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = -1, keepdim = True).detach()

        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) * (var + eps).rsqrt() * self.g

class ChanLayerNorm(nn.Module):
    def __init__(self, dim, stable = False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = 1, keepdim = True).detach()

        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * (var + eps).rsqrt() * self.g

class Always():
    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        return self.val

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class Parallel(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        outputs = [fn(x) for fn in self.fns]
        return sum(outputs)

# rearranging, so a temporal-only module sees (batch * spatial positions, frames, channels)

class RearrangeTimeCentric(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        x = rearrange(x, 'b c f ... -> b ... f c')
        x, ps = pack([x], '* f c')

        x = self.fn(x)

        x, = unpack(x, ps, '* f c')
        x = rearrange(x, 'b ... f c -> b c f ...')
        return x
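
# illustrative sketch (not part of the original module): RearrangeTimeCentric lets a
# sequence module attend purely along the frame axis, independently per spatial location
#
#   temporal_attn = RearrangeTimeCentric(Residual(Attention(dim = 64)))
#   video_features = torch.randn(2, 64, 8, 16, 16)   # (batch, channels, frames, height, width)
#   out = temporal_attn(video_features)              # same shape, attention applied over the 8 frames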

# attention pooling

class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        x = self.norm(x)
        latents = self.norm_latents(latents)

        b, h = x.shape[0], self.heads

        q = self.to_q(latents)

        # the latents attend to themselves, as well as the sequence being pooled

        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # qk rmsnorm (cosine-sim attention with learned per-dimension scales)

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities and masking

        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale

        if exists(mask):
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4,
        max_seq_len = 512,
        ff_mult = 4
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

    def forward(self, x, mask = None):
        n, device = x.shape[1], x.device
        pos_emb = self.pos_emb(torch.arange(n, device = device))

        x_with_pos = x + pos_emb

        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])

        if exists(self.to_latents_from_mean_pooled_seq):
            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
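
# illustrative sketch (not part of the original module): pool a variable-length sequence
# of text embeddings down to a fixed number of latents; with the default
# num_latents_mean_pooled = 4, four extra latents derived from the mean-pooled sequence
# are prepended to the 32 learned ones
#
#   resampler = PerceiverResampler(dim = 512, depth = 2, num_latents = 32)
#   text_tokens = torch.randn(2, 77, 512)       # (batch, sequence length, dim)
#   pooled = resampler(text_tokens)             # (2, 36, 512)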

# the pseudo-3d convolution from make-a-video: a 2d spatial conv followed by a
# causal 1d temporal conv, so pretrained image weights can be inflated to video

class Conv3d(nn.Module):
    def __init__(
        self,
        dim,
        dim_out = None,
        kernel_size = 3,
        *,
        temporal_kernel_size = None,
        **kwargs
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        temporal_kernel_size = default(temporal_kernel_size, kernel_size)

        self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
        self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None
        self.kernel_size = kernel_size

        if exists(self.temporal_conv):
            nn.init.dirac_(self.temporal_conv.weight.data) # initialized to identity, so it starts as a no-op
            nn.init.zeros_(self.temporal_conv.bias.data)

    def forward(
        self,
        x,
        ignore_time = False
    ):
        b, c, *_, h, w = x.shape

        is_video = x.ndim == 5
        ignore_time &= is_video

        if is_video:
            x = rearrange(x, 'b c f h w -> (b f) c h w')

        x = self.spatial_conv(x)

        if is_video:
            x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

        if ignore_time or not exists(self.temporal_conv):
            return x

        x = rearrange(x, 'b c f h w -> (b h w) c f')

        # causal temporal convolution - pad only on the left, so no frame sees the future

        if self.kernel_size > 1:
            x = F.pad(x, (self.kernel_size - 1, 0))

        x = self.temporal_conv(x)

        x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

        return x
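
# illustrative sketch (not part of the original module): the factorized conv preserves
# the video shape, and thanks to the dirac initialization the temporal conv is an
# identity at the start of training
#
#   conv = Conv3d(64)                           # spatial 3x3 conv + causal temporal conv
#   video = torch.randn(1, 64, 4, 32, 32)       # (batch, channels, frames, height, width)
#   out = conv(video)                           # (1, 64, 4, 32, 32)
#   out = conv(video, ignore_time = True)       # temporal conv skipped, spatial conv only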

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        causal = False,
        context_dim = None,
        rel_pos_bias = False,
        rel_pos_bias_mlp_depth = 2,
        init_zero = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale
        self.causal = causal

        self.rel_pos_bias = DynamicPositionBias(dim = dim, heads = heads, depth = rel_pos_bias_mlp_depth) if rel_pos_bias else None

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_attn_bias = nn.Parameter(torch.randn(heads))

        # single-headed key / values (multi-query attention), plus a null key / value for classifier free guidance

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

        if init_zero:
            nn.init.zeros_(self.to_out[-1].g)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None
    ):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        # add null key / value for classifier free guidance

        nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # add text conditioning, if present

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        # qk rmsnorm

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        # relative positional bias

        if not exists(attn_bias) and exists(self.rel_pos_bias):
            attn_bias = self.rel_pos_bias(n, device = device, dtype = q.dtype)

        if exists(attn_bias):
            null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
            attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# a pseudo conv2d that uses conv3d with a kernel size of 1 across the frames dimension

def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
    kernel = cast_tuple(kernel, 2)
    stride = cast_tuple(stride, 2)
    padding = cast_tuple(padding, 2)

    if len(kernel) == 2:
        kernel = (1, *kernel)

    if len(stride) == 2:
        stride = (1, *stride)

    if len(padding) == 2:
        padding = (0, *padding)

    return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
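
# illustrative sketch (not part of the original module): the lifted conv applies the same
# 2d convolution to every frame of a video tensor
#
#   conv = Conv2d(32, 64, 3, padding = 1)       # becomes nn.Conv3d with kernel (1, 3, 3)
#   video = torch.randn(1, 32, 8, 16, 16)
#   out = conv(video)                           # (1, 64, 8, 16, 16)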

class Pad(nn.Module):
    def __init__(self, padding, value = 0.):
        super().__init__()
        self.padding = padding
        self.value = value

    def forward(self, x):
        return F.pad(x, self.padding, value = self.value)

# spatial up and downsamples

def Upsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor = (1, 2, 2), mode = 'nearest'), # spatial only - a scalar factor on 5d input would also double the frame axis and break the skip connections
        Conv2d(dim, dim_out, 3, padding = 1)
    )

class PixelShuffleUpsample(nn.Module):
    # pixel shuffle upsample with ICNR-style weight initialization, which helps avoid checkerboard artifacts

    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        conv = Conv2d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU()
        )

        self.pixel_shuffle = nn.PixelShuffle(2)

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, f, h, w = conv.weight.shape
        conv_weight = torch.empty(o // 4, i, f, h, w)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        out = self.net(x)
        frames = x.shape[2]
        out = rearrange(out, 'b c f h w -> (b f) c h w')
        out = self.pixel_shuffle(out)
        return rearrange(out, '(b f) c h w -> b c f h w', f = frames)

def Downsample(dim, dim_out = None):
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b c f (h p1) (w p2) -> b (c p1 p2) f h w', p1 = 2, p2 = 2),
        Conv2d(dim * 4, dim_out, 1)
    )

# temporal up and downsamples

class TemporalPixelShuffleUpsample(nn.Module):
    def __init__(self, dim, dim_out = None, stride = 2):
        super().__init__()
        self.stride = stride
        dim_out = default(dim_out, dim)
        conv = nn.Conv1d(dim, dim_out * stride, 1)

        self.net = nn.Sequential(
            conv,
            nn.SiLU()
        )

        self.pixel_shuffle = Rearrange('b (c r) n -> b c (n r)', r = stride)

        self.init_conv_(conv)

    def init_conv_(self, conv):
        o, i, f = conv.weight.shape
        conv_weight = torch.empty(o // self.stride, i, f)
        nn.init.kaiming_uniform_(conv_weight)
        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.stride)

        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        b, c, f, h, w = x.shape
        x = rearrange(x, 'b c f h w -> (b h w) c f')
        out = self.net(x)
        out = self.pixel_shuffle(out)
        return rearrange(out, '(b h w) c f -> b c f h w', h = h, w = w)

def TemporalDownsample(dim, dim_out = None, stride = 2):
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        Rearrange('b c (f p) h w -> b (c p) f h w', p = stride),
        Conv2d(dim * stride, dim_out, 1)
    )
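
# illustrative sketch (not part of the original module): the two modules are inverses in
# shape, trading frames for channels and back
#
#   down = TemporalDownsample(64, stride = 2)
#   up = TemporalPixelShuffleUpsample(64, stride = 2)
#   video = torch.randn(1, 64, 8, 16, 16)
#   coarse = down(video)                        # (1, 64, 4, 16, 16)
#   restored = up(coarse)                       # (1, 64, 8, 16, 16)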

# positional embeddings for the diffusion timestep

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
        return torch.cat((emb.sin(), emb.cos()), dim = -1)

class LearnedSinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        fouriered = torch.cat((x, fouriered), dim = -1)
        return fouriered
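
# illustrative sketch (not part of the original module): the learned variant concatenates
# the raw timestep to the fourier features, which is why the module outputs dim + 1
# values (and why the unet below sizes its time MLP as learned_sinu_pos_emb_dim + 1)
#
#   pos_emb = LearnedSinusoidalPosEmb(16)
#   times = torch.rand(4)                       # one diffusion time per batch element
#   emb = pos_emb(times)                        # (4, 17)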

class Block(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = Conv3d(dim, dim_out, 3, padding = 1)

    def forward(
        self,
        x,
        scale_shift = None,
        ignore_time = False
    ):
        x = self.groupnorm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.activation(x)
        return self.project(x, ignore_time = ignore_time)

class ResnetBlock(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

            self.cross_attn = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()

    def forward(
        self,
        x,
        time_emb = None,
        cond = None,
        ignore_time = False
    ):
        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            time_emb = self.time_mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, ignore_time = ignore_time)

        if exists(self.cross_attn):
            assert exists(cond)
            h = rearrange(h, 'b c ... -> b ... c')
            h, ps = pack([h], 'b * c')

            h = self.cross_attn(h, context = cond) + h

            h, = unpack(h, ps, 'b * c')
            h = rearrange(h, 'b ... c -> b c ...')

        h = self.block2(h, scale_shift = scale_shift, ignore_time = ignore_time)

        h = h * self.gca(h)

        return h + self.res_conv(x)

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        scale = 8
    ):
        super().__init__()
        self.scale = scale

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance

        nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # qk rmsnorm

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        attn = sim.softmax(dim = -1, dtype = torch.float32)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class LinearCrossAttention(CrossAttention):
    def forward(self, x, context, mask = None):
        b, n, device = *x.shape[:2], x.device

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = self.heads), (q, k, v))

        # add null key / value for classifier free guidance

        nk, nv = map(lambda t: repeat(t, 'd -> (b h) 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # masking - the mask must be expanded across heads, since keys / values are shaped (b h) n d

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        # linear attention

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)

class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            Conv2d(dim, inner_dim, 1, bias = False),
            Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            Conv2d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        h, x, y = self.heads, *fmap.shape[-2:]

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (ck, cv))
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)

class GlobalContext(nn.Module):
    """ basically a superior form of squeeze-excitation that is attention-esque """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        self.to_k = Conv2d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            Conv2d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            Conv2d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.to_k(x)
        x, context = map(lambda t: rearrange(t, 'b n ... -> b n (...)'), (x, context))
        out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
        out = rearrange(out, '... -> ... 1 1')
        return self.net(out)

def FeedForward(dim, mult = 2):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False)
    )

class TimeTokenShift(nn.Module):
    def forward(self, x):
        if x.ndim != 5:
            return x

        x, x_shift = x.chunk(2, dim = 1)
        x_shift = F.pad(x_shift, (0, 0, 0, 0, 1, -1), value = 0.)
        return torch.cat((x, x_shift), dim = 1)
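
# illustrative note (not part of the original module): half the channels are shifted
# forward by one frame, so each position also sees the previous frame's features for free
#
#   shift = TimeTokenShift()
#   feats = torch.randn(1, 8, 4, 2, 2)          # (batch, channels, frames, height, width)
#   out = shift(feats)                          # channels 0-3 unchanged, channels 4-7 delayed by one frame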

def ChanFeedForward(dim, mult = 2, time_token_shift = True):
    hidden_dim = int(dim * mult)
    return Sequential(
        ChanLayerNorm(dim),
        Conv2d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        TimeTokenShift() if time_token_shift else None,
        ChanLayerNorm(hidden_dim),
        Conv2d(hidden_dim, dim, 1, bias = False)
    )

class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = rearrange(x, 'b c ... -> b ... c')
            x, ps = pack([x], 'b * c')

            x = attn(x, context = context) + x

            x, = unpack(x, ps, 'b * c')
            x = rearrange(x, 'b ... c -> b c ...')

            x = ff(x) + x
        return x

class LinearAttentionTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        ff_time_token_shift = True,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                ChanFeedForward(dim = dim, mult = ff_mult, time_token_shift = ff_time_token_shift)
            ]))

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x

class CrossEmbedLayer(nn.Module):
    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
        dim_out = default(dim_out, dim_in)

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            self.convs.append(Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))

    def forward(self, x):
        fmaps = tuple(map(lambda conv: conv(x), self.convs))
        return torch.cat(fmaps, dim = 1)

class UpsampleCombiner(nn.Module):
    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if not self.enabled:
            self.dim_out = dim
            return

        self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
        self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)

    def forward(self, x, fmaps = None):
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
            return x

        fmaps = [resize_video_to(fmap, target_size) for fmap in fmaps]
        outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
        return torch.cat((x, *outs), dim = 1)

class DynamicPositionBias(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads,
        depth
    ):
        super().__init__()
        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            LayerNorm(dim),
            nn.SiLU()
        ))

        for _ in range(max(depth - 1, 0)):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                LayerNorm(dim),
                nn.SiLU()
            ))

        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, n, device, dtype):
        i = torch.arange(n, device = device)
        j = torch.arange(n, device = device)

        indices = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j')
        indices += (n - 1)

        pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
        pos = rearrange(pos, '... -> ... 1')

        for layer in self.mlp:
            pos = layer(pos)

        bias = pos[indices]
        bias = rearrange(bias, 'i j h -> h i j')
        return bias
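
# illustrative sketch (not part of the original module): the MLP maps each relative
# distance to a per-head bias, returning a (heads, n, n) tensor added to the attention
# logits (Attention above also prepends a learned bias column for its null key)
#
#   rel_pos_bias = DynamicPositionBias(dim = 64, heads = 8, depth = 2)
#   bias = rel_pos_bias(16, device = torch.device('cpu'), dtype = torch.float32)
#   assert bias.shape == (8, 16, 16)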

# the main 3d unet

class Unet3D(nn.Module):
    def __init__(
        self,
        *,
        dim,
        text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
        num_resnet_blocks = 1,
        cond_dim = None,
        num_image_tokens = 4,
        num_time_tokens = 2,
        learned_sinu_pos_emb_dim = 16,
        out_dim = None,
        dim_mults = (1, 2, 4, 8),
        temporal_strides = 1,
        cond_images_channels = 0,
        channels = 3,
        channels_out = None,
        attn_dim_head = 64,
        attn_heads = 8,
        ff_mult = 2.,
        ff_time_token_shift = True,
        lowres_cond = False,                      # for cascading diffusion - https://cascaded-diffusion.github.io/
        layer_attns = False,
        layer_attns_depth = 1,
        layer_attns_add_text_cond = True,         # whether to condition the self-attention blocks with the text embeddings
        attend_at_middle = True,                  # whether to have a layer of attention at the bottleneck
        time_rel_pos_bias_depth = 2,
        time_causal_attn = True,
        layer_cross_attns = True,
        use_linear_attn = False,
        use_linear_cross_attn = False,
        cond_on_text = True,
        max_text_len = 256,
        init_dim = None,
        resnet_groups = 8,
        init_conv_kernel_size = 7,                # kernel size of initial conv, if not using cross embed
        init_cross_embed = True,
        init_cross_embed_kernel_sizes = (3, 7, 15),
        cross_embed_downsample = False,
        cross_embed_downsample_kernel_sizes = (2, 4),
        attn_pool_text = True,
        attn_pool_num_latents = 32,
        dropout = 0.,
        memory_efficient = False,
        init_conv_to_final_conv_residual = False,
        use_global_context_attn = True,
        scale_skip_connection = True,
        final_resnet_block = True,
        final_conv_kernel_size = 3,
        self_cond = False,
        combine_upsample_fmaps = False,           # whether to combine the outputs of all upsample blocks before the final resnet block
        pixel_shuffle_upsample = True,            # may address checkerboard artifacts
        resize_mode = 'nearest'
    ):
        super().__init__()

        # guide researchers

        assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'

        if dim < 128:
            print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')

        # save locals to take care of some hyperparameters for cascading DDPM

        self._locals = locals()
        self._locals.pop('self', None)
        self._locals.pop('__class__', None)

        self.self_cond = self_cond

        # determine dimensions

        self.channels = channels
        self.channels_out = default(channels_out, channels)

        # (1) in cascading diffusion, one concats the low resolution image onto the channel dimension
        # (2) in self conditioning, one appends the predicted x0

        init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
        init_dim = default(init_dim, dim)

        # optional image conditioning

        self.has_cond_image = cond_images_channels > 0
        self.cond_images_channels = cond_images_channels

        init_channels += cond_images_channels

        # initial convolution

        self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time conditioning

        cond_dim = default(cond_dim, dim)
        time_cond_dim = dim * 4 * (2 if lowres_cond else 1)

        sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
        sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1

        self.to_time_hiddens = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
            nn.SiLU()
        )

        self.to_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )

        # project to time tokens as well as time hiddens

        self.to_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )

        # low res aug noise conditioning

        self.lowres_cond = lowres_cond

        if lowres_cond:
            self.to_lowres_time_hiddens = nn.Sequential(
                LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
                nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
                nn.SiLU()
            )

            self.to_lowres_time_cond = nn.Sequential(
                nn.Linear(time_cond_dim, time_cond_dim)
            )

            self.to_lowres_time_tokens = nn.Sequential(
                nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
                Rearrange('b (r d) -> b r d', r = num_time_tokens)
            )

        # normalizations

        self.norm_cond = nn.LayerNorm(cond_dim)

        # text encoding conditioning (optional)

        self.text_to_cond = None

        if cond_on_text:
            assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)

        # finer control over whether to condition on text encodings

        self.cond_on_text = cond_on_text

        # attention pooling

        self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents) if attn_pool_text else None

        # for classifier free guidance

        self.max_text_len = max_text_len

        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
        self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))

        # for non-attention based text conditioning at all points in the network where time is also conditioned

        self.to_text_non_attn_cond = None

        if cond_on_text:
            self.to_text_non_attn_cond = nn.Sequential(
                nn.LayerNorm(cond_dim),
                nn.Linear(cond_dim, time_cond_dim),
                nn.SiLU(),
                nn.Linear(time_cond_dim, time_cond_dim)
            )

        # attention related params

        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)

        num_layers = len(in_out)

        # temporal attention and positional encoding generator, across video frames

        temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1)
        temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim)))

        temporal_attn = lambda dim: RearrangeTimeCentric(Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn, 'init_zero': True, 'rel_pos_bias': True})))

        # resnet block klass

        num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
        resnet_groups = cast_tuple(resnet_groups, num_layers)

        resnet_klass = partial(ResnetBlock, **attn_kwargs)

        layer_attns = cast_tuple(layer_attns, num_layers)
        layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
        layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)

        assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])

        # temporal downsample config

        temporal_strides = cast_tuple(temporal_strides, num_layers)
        self.total_temporal_divisor = functools.reduce(operator.mul, temporal_strides, 1)

        # downsample klass

        downsample_klass = Downsample

        if cross_embed_downsample:
            downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

        # initial resnet block (for the memory efficient unet)

        self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None

        self.init_temporal_peg = temporal_peg(init_dim)
        self.init_temporal_attn = temporal_attn(init_dim)

        # scale for resnet skip connections

        self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, temporal_strides]
        reversed_layer_params = list(map(reversed, layer_params))

        # downsampling layers

        skip_connect_dims = [] # keep track of skip connection dimensions

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(in_out, *layer_params)):
            is_last = ind >= (num_resolutions - 1)

            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)

            current_dim = dim_in

            # whether to pre-downsample, from the memory efficient unet

            pre_downsample = None

            if memory_efficient:
                pre_downsample = downsample_klass(dim_in, dim_out)
                current_dim = dim_out

            skip_connect_dims.append(current_dim)

            # whether to post-downsample, for the non-memory efficient unet

            post_downsample = None
            if not memory_efficient:
                post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(Conv2d(dim_in, dim_out, 3, padding = 1), Conv2d(dim_in, dim_out, 1))

            self.downs.append(nn.ModuleList([
                pre_downsample,
                resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
                temporal_peg(current_dim),
                temporal_attn(current_dim),
                TemporalDownsample(current_dim, stride = temporal_stride) if temporal_stride > 1 else None,
                post_downsample
            ]))

        # middle layers

        mid_dim = dims[-1]

        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
        self.mid_attn = EinopsToAndFrom('b c f h w', 'b (f h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
        self.mid_temporal_peg = temporal_peg(mid_dim)
        self.mid_temporal_attn = temporal_attn(mid_dim)
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

        # upsample klass

        upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample

        # upsampling layers

        upsample_fmap_dims = []

        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
            is_last = ind == (len(in_out) - 1)
            layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
            layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
            transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)

            skip_connect_dim = skip_connect_dims.pop()

            upsample_fmap_dims.append(dim_out)

            self.ups.append(nn.ModuleList([
                resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
                nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, ff_time_token_shift = ff_time_token_shift, context_dim = cond_dim, **attn_kwargs),
                temporal_peg(dim_out),
                temporal_attn(dim_out),
                TemporalPixelShuffleUpsample(dim_out, stride = temporal_stride) if temporal_stride > 1 else None,
                upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
            ]))

        # whether to combine feature maps from all upsample blocks before the final resnet block

        self.upsample_combiner = UpsampleCombiner(
            dim = dim,
            enabled = combine_upsample_fmaps,
            dim_ins = upsample_fmap_dims,
            dim_outs = dim
        )

        # whether to do a final residual from initial conv to the final resnet block

        self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
        final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)

        # final optional resnet block and convolution out

        self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None

        final_conv_dim_in = dim if final_resnet_block else final_conv_dim
        final_conv_dim_in += (channels if lowres_cond else 0)

        self.final_conv = Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)

        zero_init_(self.final_conv)

        # resize mode

        self.resize_mode = resize_mode

    # if the current settings for the unet are not correct for cascading DDPM,
    # then reinit the unet with the right settings

    def cast_model_parameters(
        self,
        *,
        lowres_cond,
        text_embed_dim,
        channels,
        channels_out,
        cond_on_text
    ):
        if lowres_cond == self.lowres_cond and \
            channels == self.channels and \
            cond_on_text == self.cond_on_text and \
            text_embed_dim == self._locals['text_embed_dim'] and \
            channels_out == self.channels_out:
            return self

        updated_kwargs = dict(
            lowres_cond = lowres_cond,
            text_embed_dim = text_embed_dim,
            channels = channels,
            channels_out = channels_out,
            cond_on_text = cond_on_text
        )

        return self.__class__(**{**self._locals, **updated_kwargs})

    # return the full unet config and its parameter state

    def to_config_and_state_dict(self):
        return self._locals, self.state_dict()

    # rehydrate the unet from its config and state dict

    @classmethod
    def from_config_and_state_dict(klass, config, state_dict):
        unet = klass(**config)
        unet.load_state_dict(state_dict)
        return unet

    # persist the unet to disk

    def persist_to_file(self, path):
        path = Path(path)
        path.parents[0].mkdir(exist_ok = True, parents = True)

        config, state_dict = self.to_config_and_state_dict()
        pkg = dict(config = config, state_dict = state_dict)
        torch.save(pkg, str(path))

    # rehydrate the unet from a file saved with `persist_to_file`

    @classmethod
    def hydrate_from_file(klass, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        assert 'config' in pkg and 'state_dict' in pkg
        config, state_dict = pkg['config'], pkg['state_dict']

        return klass.from_config_and_state_dict(config, state_dict) # `klass`, not an undefined `Unet`, so the classmethod works on this module

    # forward with classifier free guidance

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale = 1.,
        **kwargs
    ):
        logits = self.forward(*args, **kwargs)

        if cond_scale == 1:
            return logits

        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale
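
    # illustrative sketch (not part of the original module): classifier free guidance
    # interpolates between the unconditional and conditional predictions
    #
    #   guided = unet.forward_with_cond_scale(
    #       noised_video,                       # (batch, channels, frames, height, width)
    #       times,                              # (batch,)
    #       text_embeds = text_embeds,
    #       cond_scale = 5.                     # null_pred + (cond_pred - null_pred) * 5
    #   )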

    def forward(
        self,
        x,
        time,
        *,
        lowres_cond_img = None,
        lowres_noise_times = None,
        text_embeds = None,
        text_mask = None,
        cond_images = None,
        cond_video_frames = None,
        post_cond_video_frames = None,
        self_cond = None,
        cond_drop_prob = 0.,
        ignore_time = False
    ):
        assert x.ndim == 5, 'input to 3d unet must have 5 dimensions (batch, channels, time, height, width)'

        batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype

        assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames {frames} must be divisible by {self.total_temporal_divisor}'

        # add self conditioning, if needed

        if self.self_cond:
            self_cond = default(self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x, self_cond), dim = 1)

        # add low resolution conditioning, if present

        assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
        assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)

            if exists(cond_video_frames):
                lowres_cond_img = torch.cat((cond_video_frames, lowres_cond_img), dim = 2)
                cond_video_frames = torch.cat((cond_video_frames, cond_video_frames), dim = 1)

            if exists(post_cond_video_frames):
                lowres_cond_img = torch.cat((lowres_cond_img, post_cond_video_frames), dim = 2)
                post_cond_video_frames = torch.cat((post_cond_video_frames, post_cond_video_frames), dim = 1)

        # conditioning on preceding video frames as a prompt

        num_preceding_frames = 0
        if exists(cond_video_frames):
            cond_video_frames_len = cond_video_frames.shape[2]

            assert divisible_by(cond_video_frames_len, self.total_temporal_divisor)

            cond_video_frames = resize_video_to(cond_video_frames, x.shape[-1])
            x = torch.cat((cond_video_frames, x), dim = 2)

            num_preceding_frames = cond_video_frames_len

        # conditioning on succeeding video frames as a prompt

        num_succeeding_frames = 0
        if exists(post_cond_video_frames):
            cond_video_frames_len = post_cond_video_frames.shape[2]

            assert divisible_by(cond_video_frames_len, self.total_temporal_divisor)

            post_cond_video_frames = resize_video_to(post_cond_video_frames, x.shape[-1])
            x = torch.cat((x, post_cond_video_frames), dim = 2) # succeeding frames appended at the end, matching the trim at the bottom of this method

            num_succeeding_frames = cond_video_frames_len

        # condition on input image

        assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

        if exists(cond_images):
            assert cond_images.ndim == 4, 'conditioning images must have 4 dimensions only, if you want to condition on frames of video, use `cond_video_frames` instead'
            assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialization of the unet'

            cond_images = repeat(cond_images, 'b c h w -> b c f h w', f = x.shape[2])
            cond_images = resize_video_to(cond_images, x.shape[-1], mode = self.resize_mode)

            x = torch.cat((cond_images, x), dim = 1)

        # whether to ignore time in the pseudo 3d convolutions

        conv_kwargs = dict(
            ignore_time = ignore_time
        )

        # initial convolution

        x = self.init_conv(x)

        if not ignore_time:
            x = self.init_temporal_peg(x)
            x = self.init_temporal_attn(x)

        # init conv residual

        if self.init_conv_to_final_conv_residual:
            init_conv_residual = x.clone()

        # time conditioning

        time_hiddens = self.to_time_hiddens(time)

        # derive time tokens

        time_tokens = self.to_time_tokens(time_hiddens)
        t = self.to_time_cond(time_hiddens)

        # add lowres time conditioning to time hiddens
        # and add lowres time tokens along sequence dimension for attention

        if self.lowres_cond:
            lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
            lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
            lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

            t = t + lowres_t
            time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)

        # text conditioning

        text_tokens = None

        if exists(text_embeds) and self.cond_on_text:

            # conditional dropout for classifier free guidance

            text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

            text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
            text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

            # calculate text embeds

            text_tokens = self.text_to_cond(text_embeds)

            text_tokens = text_tokens[:, :self.max_text_len]

            if exists(text_mask):
                text_mask = text_mask[:, :self.max_text_len]

            text_tokens_len = text_tokens.shape[1]
            remainder = self.max_text_len - text_tokens_len

            if remainder > 0:
                text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

            if exists(text_mask):
                if remainder > 0:
                    text_mask = F.pad(text_mask, (0, remainder), value = False)

                text_mask = rearrange(text_mask, 'b n -> b n 1')
                text_keep_mask_embed = text_mask & text_keep_mask_embed

            null_text_embed = self.null_text_embed.to(text_tokens.dtype)

            text_tokens = torch.where(
                text_keep_mask_embed,
                text_tokens,
                null_text_embed
            )

            if exists(self.attn_pool):
                text_tokens = self.attn_pool(text_tokens)

            # extra non-attention conditioning, by projecting and then summing text embeddings to time
            # termed text hiddens

            mean_pooled_text_tokens = text_tokens.mean(dim = -2)

            text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

            null_text_hidden = self.null_text_hidden.to(t.dtype)

            text_hiddens = torch.where(
                text_keep_mask_hidden,
                text_hiddens,
                null_text_hidden
            )

            t = t + text_hiddens

        # main conditioning tokens (c)

        c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

        # normalize conditioning tokens

        c = self.norm_cond(c)

        # initial resnet block (for the memory efficient unet)

        if exists(self.init_resnet_block):
            x = self.init_resnet_block(x, t, **conv_kwargs)

        # go through the layers of the unet, down and up

        hiddens = []

        for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_downsample, post_downsample in self.downs:
            if exists(pre_downsample):
                x = pre_downsample(x)

            x = init_block(x, t, c, **conv_kwargs)

            for resnet_block in resnet_blocks:
                x = resnet_block(x, t, **conv_kwargs)
                hiddens.append(x)

            x = attn_block(x, c)

            if not ignore_time:
                x = temporal_peg(x)
                x = temporal_attn(x)

            hiddens.append(x)

            if exists(temporal_downsample) and not ignore_time:
                x = temporal_downsample(x)

            if exists(post_downsample):
                x = post_downsample(x)

        x = self.mid_block1(x, t, c, **conv_kwargs)

        if exists(self.mid_attn):
            x = self.mid_attn(x)

        if not ignore_time:
            x = self.mid_temporal_peg(x)
            x = self.mid_temporal_attn(x)

        x = self.mid_block2(x, t, c, **conv_kwargs)

        add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)

        up_hiddens = []

        for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_upsample, upsample in self.ups:
            if exists(temporal_upsample) and not ignore_time:
                x = temporal_upsample(x)

            x = add_skip_connection(x)
            x = init_block(x, t, c, **conv_kwargs)

            for resnet_block in resnet_blocks:
                x = add_skip_connection(x)
                x = resnet_block(x, t, **conv_kwargs)

            x = attn_block(x, c)

            if not ignore_time:
                x = temporal_peg(x)
                x = temporal_attn(x)

            up_hiddens.append(x.contiguous())

            x = upsample(x)

        # whether to combine all feature maps from upsample blocks

        x = self.upsample_combiner(x, up_hiddens)

        # final top-most residual, if needed

        if self.init_conv_to_final_conv_residual:
            x = torch.cat((x, init_conv_residual), dim = 1)

        if exists(self.final_res_block):
            x = self.final_res_block(x, t, **conv_kwargs)

        if exists(lowres_cond_img):
            x = torch.cat((x, lowres_cond_img), dim = 1)

        out = self.final_conv(x)

        # trim away the conditioning frames that were prepended / appended

        if num_preceding_frames > 0:
            out = out[:, :, num_preceding_frames:]

        if num_succeeding_frames > 0:
            out = out[:, :, :-num_succeeding_frames]

        return out
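
# illustrative usage sketch (not part of the original module): denoise one step of a
# random video, with text conditioning turned off
#
#   unet = Unet3D(dim = 128, cond_on_text = False, text_embed_dim = None)
#   video = torch.randn(1, 3, 8, 64, 64)        # (batch, channels, frames, height, width)
#   times = torch.rand(1)
#   pred = unet(video, times)                   # (1, 3, 8, 64, 64)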