# Author: yslan
# Commit: 7f51798 ("init")
# https://gist.github.com/lucidrains/5193d38d1d889681dd42feb847f1f6da
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit_3d.py
import torch
from torch import nn
from pdb import set_trace as st
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from .vit_with_mask import Transformer
# helpers
def pair(t):
    """Normalize ``t`` to a 2-tuple.

    A tuple is returned as-is; any other value is duplicated into ``(t, t)``.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
# classes
# class PreNorm(nn.Module):
# def __init__(self, dim, fn):
# super().__init__()
# self.norm = nn.LayerNorm(dim)
# self.fn = fn
# def forward(self, x, **kwargs):
# return self.fn(self.norm(x), **kwargs)
# class FeedForward(nn.Module):
# def __init__(self, dim, hidden_dim, dropout=0.):
# super().__init__()
# self.net = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(),
# nn.Dropout(dropout),
# nn.Linear(hidden_dim,
# dim), nn.Dropout(dropout))
# def forward(self, x):
# return self.net(x)
# class Attention(nn.Module):
# def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
# super().__init__()
# inner_dim = dim_head * heads
# project_out = not (heads == 1 and dim_head == dim)
# self.heads = heads
# self.scale = dim_head**-0.5
# self.attend = nn.Softmax(dim=-1)
# self.dropout = nn.Dropout(dropout)
# self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
# self.to_out = nn.Sequential(
# nn.Linear(inner_dim, dim),
# nn.Dropout(dropout)) if project_out else nn.Identity()
# def forward(self, x):
# qkv = self.to_qkv(x).chunk(3, dim=-1)
# q, k, v = map(
# lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
# dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# attn = self.attend(dots)
# attn = self.dropout(attn)
# out = torch.matmul(attn, v)
# out = rearrange(out, 'b h n d -> b n (h d)')
# return self.to_out(out)
# class Transformer(nn.Module):
# def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
# super().__init__()
# self.layers = nn.ModuleList([])
# for _ in range(depth):
# self.layers.append(
# nn.ModuleList([
# PreNorm(
# dim,
# Attention(dim,
# heads=heads,
# dim_head=dim_head,
# dropout=dropout)),
# PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
# ]))
# def forward(self, x):
# for attn, ff in self.layers:
# x = attn(x) + x
# x = ff(x) + x
# return x
# https://gist.github.com/lucidrains/213d2be85d67d71147d807737460baf4
class ViTVoxel(nn.Module):
    """Vision Transformer classifier over a voxel volume.

    The input is split into non-overlapping cubic patches of side
    ``patch_size``. Each patch is flattened and linearly projected to ``dim``.
    A learnable class token is prepended, a learnable positional embedding is
    added, and the sequence is run through ``Transformer``. The class-token
    output is passed to an MLP head that produces ``num_classes`` outputs.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        # One token per cubic patch; assumes a cube of side image_size.
        num_patches = (image_size // patch_size) ** 3
        patch_dim = channels * patch_size ** 3
        self.patch_size = patch_size
        # +1 slot for the class token prepended in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        # Pass dim_head explicitly, in the same argument order ViTTriplane uses.
        # Omitting it shifts mlp_dim into the dim_head slot and dropout into
        # the mlp_dim slot.
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, num_classes),
            nn.Dropout(dropout)
        )

    def forward(self, img, mask = None):
        """Classify a batch of volumes.

        Args:
            img: 5-D tensor laid out as (b, c, H, W, D), per the rearrange
                pattern. Each spatial axis must be divisible by
                ``patch_size``. For the positional embedding to broadcast,
                each axis is expected to equal ``image_size``.
            mask: forwarded unchanged to ``self.transformer``. Its expected
                format is defined by ``vit_with_mask.Transformer`` (not
                visible here).

        Returns:
            Output of ``self.mlp_head`` applied to the class-token output.
            The last dim is ``num_classes``.
        """
        p = self.patch_size
        # Flatten each p x p x p patch (with its channels) into one token.
        x = rearrange(img, 'b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)', p1 = p, p2 = p, p3 = p)
        x = self.patch_to_embedding(x)
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)
        x = self.transformer(x, mask)
        # Keep only the class-token position for classification.
        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)
class ViTTriplane(nn.Module):
    """ViT-style token encoder over a 5-D feature volume.

    The input (b, c, H, W, D) is split into patches of size
    (image_patch_size, image_patch_size, triplane_patch_size). Each patch is
    flattened and projected to ``dim``, and a learnable class token is
    prepended. After adding a learnable positional embedding and running
    ``Transformer``, the per-patch tokens are returned with the class token
    dropped. There is no classification head.
    """

    def __init__(self, *, image_size, triplane_size, image_patch_size, triplane_patch_size, num_classes, dim, depth, heads, mlp_dim, patch_embed=False, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        """Build the encoder.

        Args:
            image_size: expected size of the H and W axes.
            triplane_size: expected size of the D axis. A source comment
                suggests 3 planes — TODO confirm.
            image_patch_size: patch side along H and W.
            triplane_patch_size: patch extent along D.
            num_classes: accepted for signature compatibility. It is unused
                because the classification head is disabled.
            patch_embed: stored on the module but not consulted. The linear
                patch embedding is always applied.
            channels, dim, depth, heads, dim_head, mlp_dim, dropout:
                token and Transformer hyper-parameters.
            emb_dropout: dropout applied after the positional embedding.
        """
        super().__init__()
        assert image_size % image_patch_size == 0, 'image dimensions must be divisible by the patch size'
        assert triplane_size % triplane_patch_size == 0, 'triplane dimension must be divisible by the triplane patch size'
        # The D axis is patched too (p3 = triplane_patch_size in forward), so
        # it contributes triplane_size // triplane_patch_size tokens. This
        # equals the old `* triplane_size` when triplane_patch_size == 1.
        num_patches = (image_size // image_patch_size) ** 2 * (triplane_size // triplane_patch_size)
        self.patch_size = image_patch_size
        self.triplane_patch_size = triplane_patch_size
        # +1 slot for the class token prepended in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_embed = patch_embed
        # Flattened length of one (p, p, p3) patch including channels.
        patch_dim = channels * image_patch_size ** 2 * triplane_patch_size
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        # Kept for attribute compatibility; not used by forward().
        self.to_cls_token = nn.Identity()

    def forward(self, triplane, mask = None):
        """Encode a batch of volumes into patch tokens.

        Args:
            triplane: 5-D tensor laid out as (b, c, H, W, D), per the
                rearrange pattern. H and W must be divisible by
                ``patch_size`` and D by ``triplane_patch_size``. They are
                expected to equal ``image_size`` / ``triplane_size`` so the
                token count matches ``pos_embedding``.
            mask: forwarded unchanged to ``self.transformer``. Its format is
                defined by ``vit_with_mask.Transformer`` (not visible here).

        Returns:
            Transformer output with the leading class-token position removed.
            Presumably shaped (b, num_patches, dim) — confirm against
            ``vit_with_mask.Transformer``.
        """
        p = self.patch_size
        p_3d = self.triplane_patch_size
        x = rearrange(triplane, 'b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)', p1 = p, p2 = p, p3 = p_3d)
        x = self.patch_to_embedding(x)
        cls_tokens = self.cls_token.expand(triplane.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)
        x = self.transformer(x, mask)
        # Drop the class token; callers consume the per-patch tokens.
        return x[:, 1:]