# File header scraped from the hosting page (not code): author yslan,
# commit 7f51798 ("init"), file size 6.4 kB.
import torch
import torch.nn as nn
import numpy as np
import math
from einops import rearrange
from vit.vision_transformer import MemEffAttention, Attention
# from xformers.triton import FusedLayerNorm as LayerNorm
from torch.nn import LayerNorm
from xformers.components.feedforward import fused_mlp
# from xformers.components.feedforward import mlp
from xformers.components.activations import build_activation, Activation
class PositionalEncoding(nn.Module):
    """Sinusoidal frequency encoding of coordinates.

    Each scalar coordinate ``c`` is mapped to ``sin(2**k * pi * c)`` and
    ``cos(2**k * pi * c)`` for
    ``k = start_octave, ..., start_octave + num_octaves - 1``.

    Args:
        num_octaves: Number of frequency bands per coordinate component.
        start_octave: Exponent of the lowest frequency band.
    """

    def __init__(self, num_octaves=8, start_octave=0):
        super().__init__()
        self.num_octaves = num_octaves
        self.start_octave = start_octave

    def forward(self, coords, rays=None):
        """Encode ``coords`` of shape ``(..., dim)``.

        Args:
            coords: Floating-point tensor whose last axis holds the
                coordinate components. Any number of leading axes is
                accepted; ``(batch, num_points, dim)`` is the common case.
            rays: Unused; kept so existing call sites keep working.

        Returns:
            Tensor of shape ``(..., 2 * dim * num_octaves)``: all sines
            followed by all cosines. Within each half, the features of one
            coordinate component are contiguous (component-major,
            octave-minor).
        """
        octaves = torch.arange(self.start_octave,
                               self.start_octave + self.num_octaves)
        # Match coords' dtype and device so the product below stays on-device.
        octaves = octaves.float().to(coords)
        multipliers = 2**octaves * math.pi  # (num_octaves,)

        # (..., dim, 1) * (num_octaves,) broadcasts to (..., dim, num_octaves).
        scaled_coords = coords.unsqueeze(-1) * multipliers

        # Collapse the last two axes -> (..., dim * num_octaves).
        sines = torch.sin(scaled_coords).flatten(-2)
        cosines = torch.cos(scaled_coords).flatten(-2)
        return torch.cat((sines, cosines), -1)
class RayEncoder(nn.Module):
    """Concatenate frequency encodings of positions and ray directions.

    Args:
        pos_octaves: Number of frequency bands for the position encoder.
        pos_start_octave: Lowest frequency exponent for the position encoder.
        ray_octaves: Number of frequency bands for the ray encoder.
        ray_start_octave: Lowest frequency exponent for the ray encoder.
    """

    def __init__(self,
                 pos_octaves=8,
                 pos_start_octave=0,
                 ray_octaves=4,
                 ray_start_octave=0):
        super().__init__()
        self.pos_encoding = PositionalEncoding(num_octaves=pos_octaves,
                                               start_octave=pos_start_octave)
        self.ray_encoding = PositionalEncoding(num_octaves=ray_octaves,
                                               start_octave=ray_start_octave)

    def forward(self, pos, rays):
        """Encode ``pos`` and ``rays`` and concatenate the features.

        Two layouts are handled, chosen by the rank of ``rays``:

        * 4-D ``rays`` of shape ``(B, H, W, D)``: ``pos`` holds one point per
          batch element (it is unsqueezed and viewed back to ``(B, C, 1, 1)``).
          The position features are tiled over every pixel. The result is
          channel-first, with shape ``(B, C_pos + C_ray, H, W)``.
        * Any other rank: both inputs are encoded as given and joined on the
          last axis, so their leading axes must agree.
        """
        if rays.dim() != 4:
            return torch.cat((self.pos_encoding(pos), self.ray_encoding(rays)), -1)

        b, h, w, _ = rays.shape

        # One encoded position per batch element, repeated across H x W.
        pos_feat = self.pos_encoding(pos.unsqueeze(1))
        pos_feat = pos_feat.view(b, pos_feat.shape[-1], 1, 1).repeat(1, 1, h, w)

        # Encode the H*W rays as a point list, then restore the image grid
        # and move channels to axis 1.
        ray_feat = self.ray_encoding(rays.flatten(1, 2))
        ray_feat = ray_feat.view(b, h, w, ray_feat.shape[-1]).permute((0, 3, 1, 2))

        return torch.cat((pos_feat, ray_feat), 1)
# Transformer implementation based on ViT
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
class PreNorm(nn.Module):
    """Layer-normalize the input, then call the wrapped module.

    Computes ``fn(LayerNorm(x), **kwargs)``. No residual connection is
    added here; that is left to the caller.

    Args:
        dim: Size of the normalized (last) feature axis.
        fn: Callable/module applied to the normalized tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # The attribute names ``norm`` and ``fn`` form the state-dict keys.
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature width.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        hidden_drop = nn.Dropout(dropout)
        project = nn.Linear(hidden_dim, dim)
        out_drop = nn.Dropout(dropout)
        # Layer order fixes the Sequential indices (net.0 / net.3 hold the
        # weights), which are the parameter names in saved checkpoints.
        self.net = nn.Sequential(expand, activation, hidden_drop, project,
                                 out_drop)

    def forward(self, x):
        return self.net(x)
# class Attention(nn.Module):
# def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
# super().__init__()
# inner_dim = dim_head * heads
# project_out = not (heads == 1 and dim_head == dim)
# self.heads = heads
# self.scale = dim_head ** -0.5
# self.attend = nn.Softmax(dim=-1)
# if selfatt:
# self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
# else:
# self.to_q = nn.Linear(dim, inner_dim, bias=False)
# self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False)
# self.to_out = nn.Sequential(
# nn.Linear(inner_dim, dim),
# nn.Dropout(dropout)
# ) if project_out else nn.Identity()
# def forward(self, x, z=None):
# if z is None:
# qkv = self.to_qkv(x).chunk(3, dim=-1)
# else:
# q = self.to_q(x)
# k, v = self.to_kv(z).chunk(2, dim=-1)
# qkv = (q, k, v)
# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
# dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
# attn = self.attend(dots)
# out = torch.matmul(attn, v)
# out = rearrange(out, 'b h n d -> b n (h d)')
# return self.to_out(out)
class Transformer(nn.Module):
    """Stack of pre-norm transformer layers (attention + MLP per layer).

    Each layer computes ``x = attn(norm(x)) + x`` and then
    ``x = mlp(norm(x)) + x``.

    Args:
        dim: Token feature width.
        depth: Number of layers.
        heads: Number of attention heads.
        mlp_dim: Requested MLP hidden width. ``FusedMLP`` receives the integer
            multiplier ``mlp_dim // dim``, so a non-multiple of ``dim`` is
            truncated. Presumably the effective hidden width is
            ``(mlp_dim // dim) * dim``; confirm against xformers.
        dropout: Dropout rate, forwarded to the MLP only (the attention
            module is not given it).
        selfatt: Accepted for signature compatibility but currently ignored;
            every layer is self-attention.
        kv_dim: Accepted for signature compatibility but currently ignored.
        no_flash_op: Forwarded to ``MemEffAttention``.

    Raises:
        ValueError: If ``mlp_dim < dim``. The multiplier would then be 0 and
            the MLP would have a zero-width hidden layer.
    """

    def __init__(self,
                 dim,
                 depth,
                 heads,
                 mlp_dim,
                 dropout=0.,
                 selfatt=True,
                 kv_dim=None,
                 no_flash_op=False,):
        super().__init__()
        hidden_mult = mlp_dim // dim
        if hidden_mult < 1:
            raise ValueError(
                f"mlp_dim ({mlp_dim}) must be >= dim ({dim}): FusedMLP is "
                f"given the integer multiplier mlp_dim // dim = {hidden_mult}.")

        # The raw-torch ``Attention`` branch was disabled upstream; the
        # ``no_flash_op`` switch is passed through to MemEffAttention instead.
        attn_cls = MemEffAttention
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PreNorm(dim,
                            attn_cls(
                                dim,
                                num_heads=heads,
                                qkv_bias=True,
                                qk_norm=True,  # as in vit-22B
                                no_flash_op=no_flash_op,
                            )),
                    PreNorm(dim,
                            fused_mlp.FusedMLP(
                                dim,
                                hidden_layer_multiplier=hidden_mult,
                                dropout=dropout,
                                activation=Activation.GeLU)),
                ]))

    def forward(self, x):
        """Apply every layer with pre-norm residual connections."""
        for attn, ff in self.layers:  # type: ignore
            x = attn(x) + x
            x = ff(x) + x
        return x