# Provenance (copied from the hosting page): uploaded by bluestyle97,
# "Upload 147 files", commit 184193d (verified).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from inspect import isfunction
from einops import rearrange, repeat
import xformers.ops as xops
def exists(val):
    """Report whether ``val`` holds a value, i.e. is not ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    return val is not None


def default(val, d):
    """Return ``val`` if it exists, otherwise the fallback ``d``.

    When ``d`` is a plain Python function (as judged by
    ``inspect.isfunction``), it is called with no arguments so the fallback
    can be built lazily. Any other object, including builtins and classes,
    is returned unchanged.
    """
    if not exists(val):
        # Only user-defined functions/lambdas are treated as factories.
        return d() if isfunction(d) else d
    return val
class CrossAttention(nn.Module):
    """Multi-head attention computed with xformers' memory-efficient kernel.

    Queries are projected from ``x``. Keys and values are projected from
    ``context``, which falls back to ``x`` (plain self-attention) when not
    given. All projections are bias-free, and dropout is applied only after
    the output projection.

    Args:
        query_dim: feature width of ``x`` and of the returned tensor.
        context_dim: feature width of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: per-head width; the projected width is ``heads * dim_head``.
        dropout: dropout probability applied to the projected output.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        width = heads * dim_head
        kv_dim = default(context_dim, query_dim)
        self.heads = heads
        self.to_q = nn.Linear(query_dim, width, bias=False)
        self.to_k = nn.Linear(kv_dim, width, bias=False)
        self.to_v = nn.Linear(kv_dim, width, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(width, query_dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: query tokens, laid out ``(batch, tokens, query_dim)`` as
                required by the ``'b n (h d)'`` rearrange pattern.
            context: optional key/value tokens with the same layout.
            mask: accepted for interface compatibility only.
                NOTE(review): it is never used, so no mask or bias reaches
                the kernel.

        Returns:
            Tensor with the same leading dims as ``x`` and width ``query_dim``.
        """
        n_heads = self.heads
        source = default(context, x)

        def split_heads(t):
            # Fold heads into the batch axis for the 3-D kernel layout.
            return rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads)

        query = split_heads(self.to_q(x))
        key = split_heads(self.to_k(source))
        value = split_heads(self.to_v(source))
        attended = xops.memory_efficient_attention(query, key, value)
        merged = rearrange(attended, '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(merged)
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then a 4x-wide GELU MLP.

    Each sub-layer reads a bias-free LayerNorm of its input, and its output
    is added back through a residual connection.

    ``context_dim`` and ``gated_ff`` are accepted but unused. The
    feed-forward is a plain Linear-GELU-Linear stack with no dropout; the
    ``dropout`` argument only reaches the attention output projection.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):
        super().__init__()
        hidden = 4 * dim
        self.self_attn = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden, bias=False),
            nn.GELU(),
            nn.Linear(hidden, dim, bias=False),
        )
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.norm2 = nn.LayerNorm(dim, bias=False)

    def forward(self, x, context=None):
        """Apply the attention and MLP residual branches to ``x``.

        ``context`` is ignored; the block only ever self-attends.
        """
        attn_out = self.self_attn(self.norm1(x))
        x = x + attn_out
        ff_out = self.ff(self.norm2(x))
        return x + ff_out
class Transformer(nn.Module):
    """Multi-view patch transformer that predicts an ``output_dim`` map per pixel.

    Each view is cut into non-overlapping ``patch_size`` patches by a strided,
    bias-free convolution. Every token gets a learned positional embedding
    (resized when the input patch grid differs from the stored table) plus a
    view embedding: ``ref_embed`` for view 0 and ``src_embed`` for all other
    views. The tokens of all views are concatenated into one sequence and
    passed through ``depth`` blocks. The result is layer-normalised and then
    linearly mapped back to ``patch_size * patch_size * output_dim`` values
    per patch.

    Args:
        image_size: side length in pixels that the positional table is sized
            for; the table holds ``(image_size // patch_size) ** 2`` entries.
        patch_size: patch side length in pixels.
        input_dim: channels per input view.
        inner_dim: token width.
        output_dim: channels per output pixel.
        n_heads: attention heads per block (head size ``inner_dim // n_heads``).
        depth: number of transformer blocks.
        dropout: dropout probability forwarded to each block.
    """

    def __init__(
        self,
        image_size=512,
        patch_size=8,
        input_dim=3,
        inner_dim=1024,
        output_dim=14,
        n_heads=16,
        depth=24,
        dropout=0.,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.input_dim = input_dim
        self.inner_dim = inner_dim
        self.output_dim = output_dim

        self.patchify = nn.Conv2d(input_dim, inner_dim, kernel_size=patch_size, stride=patch_size, padding=0, bias=False)
        num_patches = (image_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, inner_dim))
        self.ref_embed = nn.Parameter(torch.zeros(1, 1, inner_dim))
        self.src_embed = nn.Parameter(torch.zeros(1, 1, inner_dim))

        self.blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, inner_dim // n_heads, dropout=dropout)
             for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(inner_dim, bias=False)
        self.unpatchify = nn.Linear(inner_dim, patch_size ** 2 * output_dim, bias=True)

        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.ref_embed, std=.02)
        nn.init.trunc_normal_(self.src_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear weights as truncated normal and LayerNorm as identity; zero any biases."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def interpolate_pos_encoding(self, x, w, h):
        """Return positional embeddings that fit a ``w`` x ``h`` pixel input.

        Args:
            x: token tensor whose second-to-last dim is the number of patches
                per view and whose last dim is the embedding width.
            w: input width in pixels (the last image axis).
            h: input height in pixels.

        Returns:
            ``self.pos_embed`` unchanged when the patch count matches and the
            input is square. Otherwise a
            ``(1, (h // patch_size) * (w // patch_size), dim)`` tensor
            flattened height-major, matching the ``(h w)`` token order
            produced in ``forward``.
        """
        npatch = x.shape[-2]
        N = self.pos_embed.shape[-2]
        if npatch == N and w == h:
            return self.pos_embed
        patch_pos_embed = self.pos_embed
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2).contiguous(),
            # F.interpolate orders spatial axes as (rows, cols) = (height, width);
            # scaling rows by the width transposed the grid for non-square inputs.
            scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        # (1, dim, rows, cols) -> (1, rows * cols, dim), row-major like the tokens.
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim).contiguous()
        return patch_pos_embed

    def forward(self, images):
        """Run the transformer over a batch of multi-view images.

        Args:
            images: ``(B, N, C, H, W)`` tensor with ``C == input_dim``. View 0
                of each sample gets ``ref_embed``; views ``1..N-1`` get
                ``src_embed``.

        Returns:
            ``(B, N, Hp * p, Wp * p, output_dim)`` tensor, where
            ``p = patch_size``, ``Hp = H // p`` and ``Wp = W // p``.
        """
        B, N, _, H, W = images.shape

        # patchify
        images = rearrange(images, 'b n c h w -> (b n) c h w')
        tokens = self.patchify(images)
        tokens = rearrange(tokens, 'bn c h w -> bn (h w) c')

        # add positional encodings (shared by all views) and per-view embeddings
        tokens = rearrange(tokens, '(b n) hw c -> b n hw c', b=B)
        tokens = tokens + self.interpolate_pos_encoding(tokens, W, H).unsqueeze(1)
        view_embeds = torch.cat([self.ref_embed, self.src_embed.repeat(1, N - 1, 1)], dim=1)
        tokens = tokens + view_embeds.unsqueeze(2)

        # transformer over the concatenated tokens of all views
        tokens = rearrange(tokens, 'b n hw c -> b (n hw) c')
        x = tokens
        for layer in self.blocks:
            x = layer(x)

        # unpatchify
        x = self.norm(x)
        x = self.unpatchify(x)
        x = rearrange(x, 'b (n h w) c -> b n h w c', n=N, h=H // self.patch_size, w=W // self.patch_size)
        x = rearrange(x, 'b n h w (p q c) -> b n (h p) (w q) c', p=self.patch_size, q=self.patch_size)
        out = x
        return out