import torch
import torch.nn as nn
import numpy as np
import math

from einops import rearrange
from vit.vision_transformer import MemEffAttention, Attention
# from xformers.triton import FusedLayerNorm as LayerNorm
from torch.nn import LayerNorm
from xformers.components.feedforward import fused_mlp
# from xformers.components.feedforward import mlp
from xformers.components.activations import build_activation, Activation


class PositionalEncoding(nn.Module):
    """Fourier-feature encoding: each coordinate is mapped to sines and cosines
    at `num_octaves` frequencies starting from octave `start_octave`."""

    def __init__(self, num_octaves=8, start_octave=0):
        super().__init__()
        self.num_octaves = num_octaves
        self.start_octave = start_octave

    def forward(self, coords, rays=None):
        batch_size, num_points, dim = coords.shape

        octaves = torch.arange(self.start_octave,
                               self.start_octave + self.num_octaves)
        octaves = octaves.float().to(coords)
        multipliers = 2**octaves * math.pi
        coords = coords.unsqueeze(-1)
        while len(multipliers.shape) < len(coords.shape):
            multipliers = multipliers.unsqueeze(0)

        scaled_coords = coords * multipliers

        sines = torch.sin(scaled_coords).reshape(batch_size, num_points,
                                                 dim * self.num_octaves)
        cosines = torch.cos(scaled_coords).reshape(batch_size, num_points,
                                                   dim * self.num_octaves)

        result = torch.cat((sines, cosines), -1)
        return result
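

# Shape sketch (illustrative note, not part of the original module): for coords
# of shape (batch, num_points, dim), PositionalEncoding returns
# (batch, num_points, 2 * dim * num_octaves), sines followed by cosines, e.g.
#
#   enc = PositionalEncoding(num_octaves=8)
#   out = enc(torch.rand(2, 100, 3))   # out.shape == (2, 100, 48)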


class RayEncoder(nn.Module):
    """Positionally encodes a camera position and per-ray directions, then
    concatenates the two encodings."""

    def __init__(self,
                 pos_octaves=8,
                 pos_start_octave=0,
                 ray_octaves=4,
                 ray_start_octave=0):
        super().__init__()
        self.pos_encoding = PositionalEncoding(num_octaves=pos_octaves,
                                               start_octave=pos_start_octave)
        self.ray_encoding = PositionalEncoding(num_octaves=ray_octaves,
                                               start_octave=ray_start_octave)

    def forward(self, pos, rays):
        if len(rays.shape) == 4:
            # Image-shaped rays (batch, height, width, dims) with one camera
            # position per batch element: broadcast the position encoding over
            # the image grid and return a channels-first feature map.
            batchsize, height, width, dims = rays.shape
            pos_enc = self.pos_encoding(pos.unsqueeze(1))
            pos_enc = pos_enc.view(batchsize, pos_enc.shape[-1], 1, 1)
            pos_enc = pos_enc.repeat(1, 1, height, width)

            rays = rays.flatten(1, 2)
            ray_enc = self.ray_encoding(rays)
            ray_enc = ray_enc.view(batchsize, height, width, ray_enc.shape[-1])
            ray_enc = ray_enc.permute((0, 3, 1, 2))
            x = torch.cat((pos_enc, ray_enc), 1)
        else:
            # Flat inputs of shape (batch, num_points, dims): concatenate the
            # encodings along the feature dimension.
            pos_enc = self.pos_encoding(pos)
            ray_enc = self.ray_encoding(rays)
            x = torch.cat((pos_enc, ray_enc), -1)
        return x
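

# Shape sketch (illustrative note, not part of the original module): with the
# defaults (pos_octaves=8, ray_octaves=4) and 3-D positions/rays, the encoder
# produces 2*3*8 + 2*3*4 = 72 feature channels, e.g.
#
#   enc = RayEncoder()
#   x = enc(torch.rand(2, 3), torch.rand(2, 60, 80, 3))    # x.shape == (2, 72, 60, 80)
#   x = enc(torch.rand(2, 100, 3), torch.rand(2, 100, 3))  # x.shape == (2, 100, 72)
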
# Transformer implementation based on ViT | |
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py | |
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
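

# Note (editorial): FeedForward is a plain-PyTorch MLP block. It is not used by
# the Transformer below, which relies on xformers' FusedMLP for the equivalent
# Linear -> GELU -> Dropout -> Linear -> Dropout pattern.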


# class Attention(nn.Module):
#     def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
#         super().__init__()
#         inner_dim = dim_head * heads
#         project_out = not (heads == 1 and dim_head == dim)
#         self.heads = heads
#         self.scale = dim_head ** -0.5
#         self.attend = nn.Softmax(dim=-1)
#         if selfatt:
#             self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
#         else:
#             self.to_q = nn.Linear(dim, inner_dim, bias=False)
#             self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False)
#         self.to_out = nn.Sequential(
#             nn.Linear(inner_dim, dim),
#             nn.Dropout(dropout)
#         ) if project_out else nn.Identity()
#
#     def forward(self, x, z=None):
#         if z is None:
#             qkv = self.to_qkv(x).chunk(3, dim=-1)
#         else:
#             q = self.to_q(x)
#             k, v = self.to_kv(z).chunk(2, dim=-1)
#             qkv = (q, k, v)
#         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
#         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
#         attn = self.attend(dots)
#         out = torch.matmul(attn, v)
#         out = rearrange(out, 'b h n d -> b n (h d)')
#         return self.to_out(out)


class Transformer(nn.Module):
    """Pre-norm ViT-style encoder: each layer is memory-efficient attention
    followed by a fused MLP, both wrapped in PreNorm with residual connections."""

    def __init__(self,
                 dim,
                 depth,
                 heads,
                 mlp_dim,
                 dropout=0.,
                 selfatt=True,
                 kv_dim=None,
                 no_flash_op=False):
        super().__init__()
        # NOTE: `selfatt` and `kv_dim` are currently unused in this implementation.
        # if no_flash_op:
        #     attn_cls = Attention  # raw torch attention
        # else:
        attn_cls = MemEffAttention
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PreNorm(dim,
                            attn_cls(
                                dim,
                                num_heads=heads,
                                qkv_bias=True,
                                qk_norm=True,  # as in ViT-22B
                                no_flash_op=no_flash_op,
                            )),
                    PreNorm(dim,
                            fused_mlp.FusedMLP(dim,
                                               # mlp.MLP(dim,
                                               hidden_layer_multiplier=mlp_dim // dim,
                                               dropout=dropout,
                                               activation=Activation.GeLU)),
                ]))

    def forward(self, x):
        for attn, ff in self.layers:  # type: ignore
            x = attn(x) + x
            x = ff(x) + x
        return x
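

if __name__ == "__main__":
    # Minimal smoke-test sketch (an editorial example, not part of the original
    # module). It assumes the project's `vit.vision_transformer` is importable
    # and that xformers is built with FusedMLP / memory-efficient attention
    # support, which in practice usually means running on a CUDA device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Transformer(dim=768, depth=2, heads=12, mlp_dim=3072).to(device)
    tokens = torch.randn(2, 196, 768, device=device)  # (batch, tokens, dim)
    with torch.no_grad():
        out = model(tokens)
    print(out.shape)  # expected: torch.Size([2, 196, 768])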