import math
from inspect import isfunction

import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers.ops as xops
from einops import rearrange, repeat


def exists(val):
    """Return True when ``val`` is not ``None``."""
    return val is not None


def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When ``d`` is a Python function (as reported by ``inspect.isfunction``)
    it is called with no arguments and its result is returned; any other
    object is returned as-is.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


class CrossAttention(nn.Module):
    """Multi-head attention computed with ``xops.memory_efficient_attention``.

    Queries are projected from ``x``; keys and values are projected from
    ``context``, or from ``x`` itself when no context is supplied.

    Args:
        query_dim: feature size of the query input (and of the output).
        context_dim: feature size of the context input; defaults to
            ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` when context is None).

        Args:
            x: query input, laid out as ``(b, n, query_dim)``.
            context: optional key/value input, laid out as
                ``(b, m, context_dim)``.
            mask: accepted for interface compatibility only.
                NOTE(review): it is not forwarded to the attention kernel,
                so it currently has no effect.

        Returns:
            Tensor laid out as ``(b, n, query_dim)``.
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Fold the head axis into the batch axis before calling the kernel.
        q, k, v = (
            rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)
        )

        out = xops.memory_efficient_attention(q, k, v)

        # Unfold the heads back into the feature axis.
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention followed by a GELU MLP.

    Args:
        dim: token feature size.
        n_heads: number of attention heads.
        d_head: feature size of each head.
        dropout: dropout probability, passed to the attention layer only.
        context_dim: accepted but not used by this block.
        gated_ff: accepted but not used by this block.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):
        super().__init__()
        self.self_attn = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4, bias=False),
            nn.GELU(),
            nn.Linear(dim * 4, dim, bias=False),
        )
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.norm2 = nn.LayerNorm(dim, bias=False)

    def forward(self, x, context=None):
        """Apply self-attention and the MLP, each with a residual connection.

        ``context`` is accepted for interface compatibility but is not used.
        """
        before_sa = self.norm1(x)
        x = x + self.self_attn(before_sa)
        x = self.ff(self.norm2(x)) + x
        return x


class Transformer(nn.Module):
    """Patch-based transformer over a set of views.

    Each view is split into ``patch_size`` x ``patch_size`` patches by a
    strided convolution; the tokens of all views are concatenated and run
    through ``depth`` transformer blocks, then projected back to
    ``output_dim`` channels per pixel.

    Args:
        image_size: side length used to size the learned positional
            embedding (``(image_size // patch_size) ** 2`` entries).
        patch_size: side length of a patch in pixels.
        input_dim: number of input channels per view.
        inner_dim: token feature size.
        output_dim: number of output channels per pixel.
        n_heads: number of attention heads in every block.
        depth: number of transformer blocks.
        dropout: dropout probability passed to every block.
    """

    def __init__(
        self,
        image_size=512,
        patch_size=8,
        input_dim=3,
        inner_dim=1024,
        output_dim=14,
        n_heads=16,
        depth=24,
        dropout=0.,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.input_dim = input_dim
        self.inner_dim = inner_dim
        self.output_dim = output_dim

        self.patchify = nn.Conv2d(
            input_dim, inner_dim,
            kernel_size=patch_size, stride=patch_size, padding=0, bias=False,
        )

        num_patches = (image_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, inner_dim))
        # One learned embedding for the first view and one shared by all others.
        self.ref_embed = nn.Parameter(torch.zeros(1, 1, inner_dim))
        self.src_embed = nn.Parameter(torch.zeros(1, 1, inner_dim))

        self.blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim, n_heads, inner_dim // n_heads, dropout=dropout
                )
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(inner_dim, bias=False)
        self.unpatchify = nn.Linear(inner_dim, patch_size ** 2 * output_dim, bias=True)

        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.ref_embed, std=.02)
        nn.init.trunc_normal_(self.src_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` submodules.

        Linear weights get a truncated normal (std 0.02) and zero bias;
        LayerNorm weights are set to 1 and biases to 0. Other module types
        are left untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def interpolate_pos_encoding(self, x, w, h):
        """Return the positional embedding resized to a ``w`` x ``h`` image.

        Args:
            x: token tensor; only ``x.shape[-2]`` (tokens per view) and
                ``x.shape[-1]`` (feature size) are read.
            w: image width in pixels.
            h: image height in pixels.

        Returns:
            ``self.pos_embed`` itself when the token count already matches
            and the image is square; otherwise a bicubically interpolated
            tensor laid out as ``(1, (h // patch_size) * (w // patch_size), dim)``
            with rows along the height axis, matching the row-major
            ``(h w)`` token order produced in ``forward``.
        """
        npatch = x.shape[-2]
        N = self.pos_embed.shape[-2]
        if npatch == N and w == h:
            return self.pos_embed

        dim = x.shape[-1]
        side = int(math.sqrt(N))
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1

        # The learned grid is stored flat; restore it to (1, dim, side, side).
        grid = self.pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2).contiguous()
        # scale_factor is ordered like the spatial dims of ``grid``: rows
        # first, then columns. Rows must follow the image height and columns
        # the image width, otherwise non-square inputs receive a transposed
        # grid whose flattened order no longer lines up with the tokens.
        patch_pos_embed = F.interpolate(
            grid,
            scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim).contiguous()
        return patch_pos_embed

    def forward(self, images):
        """Run the transformer over a batch of view sets.

        Args:
            images: tensor laid out as ``(B, N, C, H, W)``. The first view
                along ``N`` receives ``ref_embed``; the remaining ``N - 1``
                views receive ``src_embed``.

        Returns:
            Tensor laid out as ``(B, N, H', W', output_dim)`` where
            ``H' = (H // patch_size) * patch_size`` and likewise for ``W'``.
        """
        B, N, _, H, W = images.shape

        # patchify: every view is embedded independently
        images = rearrange(images, 'b n c h w -> (b n) c h w')
        tokens = self.patchify(images)
        tokens = rearrange(tokens, 'bn c h w -> bn (h w) c')

        # add pos encodings (shared across views) and per-view embeddings
        tokens = rearrange(tokens, '(b n) hw c -> b n hw c', b=B)
        tokens = tokens + self.interpolate_pos_encoding(tokens, W, H).unsqueeze(1)
        view_embeds = torch.cat(
            [self.ref_embed, self.src_embed.repeat(1, N - 1, 1)], dim=1
        )
        tokens = tokens + view_embeds.unsqueeze(2)

        # transformer: all views attend to each other in a single sequence
        x = rearrange(tokens, 'b n hw c -> b (n hw) c')
        for layer in self.blocks:
            x = layer(x)

        # unpatchify: project each token to a patch of output pixels
        x = self.norm(x)
        x = self.unpatchify(x)
        x = rearrange(
            x, 'b (n h w) c -> b n h w c',
            n=N, h=H // self.patch_size, w=W // self.patch_size,
        )
        x = rearrange(
            x, 'b n h w (p q c) -> b n (h p) (w q) c',
            p=self.patch_size, q=self.patch_size,
        )
        return x