# LASA/models/modules/point_transformer.py
# Commit cc9780d by HaolinLiu: "first commit of codes and update readme.md"
# File size when captured: 14.3 kB
from functools import wraps

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from timm.models.layers import DropPath
from torch import einsum, nn
from torch_cluster import fps
def zero_module(module):
    """Overwrite every parameter of ``module`` with zeros, in place.

    The update is not recorded by autograd.

    Args:
        module: the ``nn.Module`` whose parameters are cleared.

    Returns:
        The very same ``module`` object, so the call can wrap a constructor.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding of scalar positions (e.g. noise levels).

    A 1-D input of length ``L`` becomes an ``(L, 2 * (num_channels // 2))``
    tensor. The first half of the columns hold cosines and the second half
    sines, over a geometric ladder of frequencies running from 1 towards
    ``1 / max_positions``.

    Args:
        num_channels: requested width. Only ``num_channels // 2`` frequencies
            are generated, so an odd value yields one channel fewer.
        max_positions: sets the lowest frequency of the ladder.
        endpoint: when True the last frequency is exactly
            ``1 / max_positions``; otherwise the ladder stops one step short.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        half = self.num_channels // 2
        denom = half - 1 if self.endpoint else half
        exponents = torch.arange(half, dtype=torch.float32, device=x.device) / denom
        freqs = ((1 / self.max_positions) ** exponents).to(x.dtype)
        # outer product: one row per input position, one column per frequency
        angles = torch.outer(x, freqs)
        return torch.cat((angles.cos(), angles.sin()), dim=1)
class CrossAttention(nn.Module):
    """Multi-head (cross-)attention with bias-free q/k/v projections.

    Queries come from ``x``. Keys and values come from ``context``, or from
    ``x`` itself when no context is given (self-attention).

    Args:
        query_dim: feature size of ``x`` and of the output.
        context_dim: feature size of ``context``; defaults to ``query_dim``.
        heads: number of attention heads.
        dim_head: per-head feature size; inner width is ``heads * dim_head``.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context``.

        Args:
            x: query tensor of shape ``(b, n, query_dim)``.
            context: key/value tensor ``(b, m, context_dim)``; ``None`` means
                self-attention over ``x``.
            mask: optional boolean tensor with the batch axis first. It is
                flattened to ``(b, m)``; ``True`` marks context positions
                that may be attended to, and ``False`` positions are excluded.

        Returns:
            Tensor of shape ``(b, n, query_dim)``.
        """
        h = self.heads
        q = self.to_q(x)
        if context is None:
            context = x
        k = self.to_k(context)
        v = self.to_v(context)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        if mask is not None:
            # boolean keep-mask, broadcast over heads and query positions
            mask = rearrange(mask, 'b ... -> b (...)')
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
        attn = sim.softmax(dim=-1)
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
class LayerScale(nn.Module):
    """Scale features by a learnable per-channel gain ``gamma``.

    Args:
        dim: size of the last (channel) axis; ``gamma`` has shape ``(dim,)``.
        init_values: starting value for every entry of ``gamma``.
        inplace: if True, ``x`` is multiplied in place and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        if not self.inplace:
            return x * self.gamma
        return x.mul_(self.gamma)
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    A single linear layer emits ``2 * dim_out`` features. The first half is
    the value, and the second half, passed through GELU, is the gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return F.gelu(gate) * value
class FeedForward(nn.Module):
    """Position-wise MLP: expand by ``mult``, activate, dropout, project.

    Args:
        dim: input feature size.
        dim_out: output feature size; defaults to ``dim``.
        mult: expansion factor; the hidden width is ``int(dim * mult)``.
        glu: use a ``GEGLU`` expansion when True, else Linear followed by GELU.
        dropout: dropout probability on the hidden activations.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_features = dim if dim_out is None else dim_out
        if glu:
            expand = GEGLU(dim, hidden)
        else:
            expand = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        # state_dict keys (net.0.*, net.2.*) follow this exact layout
        self.net = nn.Sequential(expand, nn.Dropout(dropout), nn.Linear(hidden, out_features))

    def forward(self, x):
        return self.net(x)
class AdaLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning tensor.

    ``timestep`` is projected to ``2 * n_embd`` features and split on dim 2
    into ``scale`` and ``shift``. The result is ``LN(x) * (1 + scale) + shift``,
    where ``LN`` has no learnable affine parameters.

    Note: ``self.silu`` is constructed but never applied in ``forward``.
    """

    def __init__(self, n_embd):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)

    def forward(self, x, timestep):
        # split is on dim 2, so timestep must be at least 3-D (e.g. b, 1, n_embd)
        scale, shift = self.linear(timestep).chunk(2, dim=2)
        normed = self.layernorm(x)
        return normed * (1 + scale) + shift
class BasicTransformerBlock(nn.Module):
    """Pre-norm block: self-attention, cross-attention, then feed-forward.

    Each sub-layer is preceded by an ``AdaLayerNorm`` conditioned on ``t``
    and wrapped in a residual connection. LayerScale and DropPath can
    optionally be applied to each residual branch.

    Args:
        dim: token feature size.
        n_heads: heads in both attention layers.
        d_head: per-head feature size.
        dropout: dropout inside the attention output projections and the MLP.
        context_dim: feature size of the cross-attention context. ``None``
            makes the second attention a self-attention.
        gated_ff: use a GEGLU feed-forward when True.
        checkpoint: stored on the module; not used by ``forward``.
        init_values: LayerScale initial gain. A falsy value (the default 0)
            replaces LayerScale with identities.
        drop_path: DropPath probability for the residual branches. 0 (the
            default) replaces DropPath with identities.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True,
                 checkpoint=True, init_values=0, drop_path=0.0):
        super().__init__()
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head,
                                    dropout=dropout)  # self-attn if context is None
        self.norm1 = AdaLayerNorm(dim)
        self.norm2 = AdaLayerNorm(dim)
        self.norm3 = AdaLayerNorm(dim)
        self.checkpoint = checkpoint
        self.ls1 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.ls2 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.ls3 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path3 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, t, context=None):
        """Run the three residual sub-layers.

        Args:
            x: token tensor.
            t: conditioning passed to every ``AdaLayerNorm``.
            context: cross-attention context for ``attn2``; ``None`` makes it
                attend to its own input.
        """
        x = self.drop_path1(self.ls1(self.attn1(self.norm1(x, t)))) + x
        x = self.drop_path2(self.ls2(self.attn2(self.norm2(x, t), context=context))) + x
        x = self.drop_path3(self.ls3(self.ff(self.norm3(x, t)))) + x
        return x
class LatentArrayTransformer(nn.Module):
    """Time- and class-conditioned transformer over a set of tokens.

    Tokens ``x`` (assumed ``(b, n, in_channels)`` — confirm with callers) are
    projected to ``n_heads * d_head`` features. They then pass through
    ``depth`` blocks, which are conditioned on the sum of a time embedding and
    a class embedding and cross-attend to ``cond``. Finally they are
    layer-normed and projected to the output width. No reshaping is done; the
    token axis is preserved.

    Args:
        in_channels: feature size of the input tokens.
        t_channels: width of the sinusoidal embedding of ``t``.
        n_heads: attention heads per block.
        d_head: per-head size; the model width is ``n_heads * d_head``.
        depth: number of transformer blocks.
        dropout: dropout passed to each block.
        context_dim: feature size of ``cond`` for cross-attention.
        out_channels: output feature size; defaults to ``in_channels``.
        context_dim2: accepted but unused.
        block: block class instantiated ``depth`` times.
    """

    def __init__(self, in_channels, t_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
                 block=BasicTransformerBlock):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.t_channels = t_channels
        self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)
        self.transformer_blocks = nn.ModuleList(
            [block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
             for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(inner_dim)
        # zero-initialised, bias-free output projection: the model outputs zeros at init
        if out_channels is None:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False))
        else:
            self.num_cls = out_channels
            self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False))
        self.context_dim = context_dim
        self.map_noise = PositionalEmbedding(t_channels)
        self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim)
        self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim)

    def forward(self, x, t, cond, class_emb):
        """Denoise/transform ``x`` conditioned on ``t``, ``cond`` and ``class_emb``.

        Args:
            x: input tokens.
            t: 1-D tensor with one timestep/noise level per sample.
            cond: cross-attention context (``None`` makes it self-attention).
            class_emb: 2-D per-sample class embedding, broadcast over tokens.
                Its width presumably equals the model width — TODO confirm.

        Returns:
            Tokens projected to the output width.
        """
        t_emb = self.map_noise(t)[:, None]
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        # the conditioning is identical for every block, so build it once
        cond_emb = t_emb + class_emb[:, None, :]
        x = self.proj_in(x)
        for block in self.transformer_blocks:
            x = block(x, cond_emb, context=cond)
        x = self.norm(x)
        return self.proj_out(x)
class PointTransformer(nn.Module):
    """Time-conditioned transformer over a set of tokens.

    Tokens are linearly lifted to ``n_heads * d_head`` features. They are
    processed by ``depth`` blocks conditioned on an MLP-refined sinusoidal
    embedding of ``t``, with optional cross-attention to ``cond``. They are
    then layer-normed and projected by a zero-initialised linear layer. The
    token axis is never reshaped.

    Args:
        in_channels: feature size of the input tokens.
        t_channels: width of the sinusoidal embedding of ``t``.
        n_heads: attention heads per block.
        d_head: per-head size; the model width is ``n_heads * d_head``.
        depth: number of blocks.
        dropout: dropout passed to each block.
        context_dim: feature size of ``cond``.
        out_channels: output width; defaults to ``in_channels``.
        context_dim2: accepted but unused.
        block: block class instantiated ``depth`` times.
    """

    def __init__(self, in_channels, t_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
                 block=BasicTransformerBlock):
        super().__init__()
        width = n_heads * d_head
        self.in_channels = in_channels
        self.t_channels = t_channels
        self.proj_in = nn.Linear(in_channels, width, bias=False)
        blocks = [block(width, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                  for _ in range(depth)]
        self.transformer_blocks = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(width)
        if out_channels is not None:
            self.num_cls = out_channels
        target = in_channels if out_channels is None else out_channels
        self.proj_out = zero_module(nn.Linear(width, target, bias=False))
        self.context_dim = context_dim
        self.map_noise = PositionalEmbedding(t_channels)
        self.map_layer0 = nn.Linear(in_features=t_channels, out_features=width)
        self.map_layer1 = nn.Linear(in_features=width, out_features=width)

    def forward(self, x, t, cond=None):
        # time embedding gets a singleton token axis so it broadcasts over tokens
        t_emb = self.map_noise(t)[:, None]
        for layer in (self.map_layer0, self.map_layer1):
            t_emb = F.silu(layer(t_emb))
        hidden = self.proj_in(x)
        for blk in self.transformer_blocks:
            hidden = blk(hidden, t_emb, context=cond)
        return self.proj_out(self.norm(hidden))
def exists(val):
    """Return True unless ``val`` is ``None`` (falsy values still count)."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val``, or ``d`` when ``val`` is ``None``.

    Falsy-but-not-None values such as ``0`` or ``""`` are returned as-is.
    """
    if val is None:
        return d
    return val
def cache_fn(f):
    """Memoise the *first* result of ``f`` and return it on every later call.

    The cache ignores arguments: once a result is stored, every cached call
    returns that same object, whatever ``*args``/``**kwargs`` are passed.
    Call with ``_cache=False`` to bypass the cache. ``f`` then runs directly,
    and its result is neither read from nor written to the cache.

    A private sentinel marks "not yet computed", so a first result of
    ``None`` is cached like any other value.
    """
    missing = object()
    cache = missing

    @wraps(f)
    def cached_fn(*args, _cache=True, **kwargs):
        nonlocal cache
        if not _cache:
            return f(*args, **kwargs)
        if cache is missing:
            cache = f(*args, **kwargs)
        return cache

    return cached_fn
class PreNorm(nn.Module):
    """Layer-normalise the input (and optionally the context) before ``fn``.

    Args:
        dim: feature size of ``x``.
        fn: callable invoked as ``fn(normed_x, **kwargs)``.
        context_dim: when given, ``kwargs['context']`` is normalised with its
            own LayerNorm of this size, and must then be supplied (a missing
            ``context`` raises ``KeyError``).
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = None if context_dim is None else nn.LayerNorm(context_dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        if self.norm_context is not None:
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed, **kwargs)
class Attention(nn.Module):
    """Multi-head attention with a fused key/value projection.

    Args:
        query_dim: feature size of the queries and of the output.
        context_dim: feature size of the context; defaults to ``query_dim``.
        heads: number of heads.
        dim_head: per-head size; inner width is ``heads * dim_head``.
        drop_path_rate: DropPath probability on the output; 0 disables it.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, drop_path_rate=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        kv_dim = query_dim if context_dim is None else context_dim
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (self-attention when it is None).

        ``mask``, if given, is a boolean tensor with the batch axis first.
        It is flattened per sample; ``False`` entries are excluded from
        attention.
        """
        heads = self.heads
        source = x if context is None else context
        q = self.to_q(x)
        k, v = self.to_kv(source).chunk(2, dim=-1)
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (q, k, v))
        scores = einsum('b i d, b j d -> b i j', q, k) * self.scale
        if mask is not None:
            keep = rearrange(mask, 'b ... -> b (...)')
            keep = repeat(keep, 'b j -> (b h) () j', h=heads)
            scores = scores.masked_fill(~keep, -torch.finfo(scores.dtype).max)
        weights = scores.softmax(dim=-1)
        merged = einsum('b i j, b j d -> b i d', weights, v)
        merged = rearrange(merged, '(b h) n d -> b n (h d)', h=heads)
        return self.drop_path(self.to_out(merged))
class PointEmbed(nn.Module):
    """Fourier-feature embedding of xyz points followed by a linear layer.

    Each coordinate is projected onto ``hidden_dim // 6`` frequencies
    ``pi * 2**k`` (k = 0, 1, ...). The sines and cosines of these projections
    (``hidden_dim`` features in total) are concatenated with the raw
    coordinates and mapped to ``dim`` features.

    Args:
        hidden_dim: number of Fourier features; must be divisible by 6.
        dim: output feature size.
    """

    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()
        assert hidden_dim % 6 == 0
        self.embedding_dim = hidden_dim
        n_freqs = self.embedding_dim // 6
        freqs = torch.pow(2, torch.arange(n_freqs)).float() * np.pi
        # block-diagonal basis of shape (3, hidden_dim // 2): row i carries the
        # frequencies for coordinate i and zeros everywhere else
        self.register_buffer('basis', torch.block_diag(freqs, freqs, freqs))
        self.mlp = nn.Linear(self.embedding_dim + 3, dim)

    @staticmethod
    def embed(input, basis):
        """Return ``sin`` then ``cos`` of ``input @ basis``, concatenated on dim 2."""
        proj = torch.einsum('bnd,de->bne', input, basis)
        return torch.cat((proj.sin(), proj.cos()), dim=2)

    def forward(self, input):
        # input: B x N x 3 point coordinates -> B x N x dim
        fourier = self.embed(input, self.basis)
        return self.mlp(torch.cat((fourier, input), dim=2))
class PointEncoder(nn.Module):
    """Encode a point cloud into a set of latent tokens.

    Farthest-point sampling picks a subset of the input points. Their
    embeddings cross-attend (single head) to the embeddings of all input
    points, then go through a pre-norm feed-forward layer with a residual,
    and are finally projected to ``latent_dim``.

    Args:
        dim: model width of the point embedding and the attention block.
        num_inputs: exact number of points each input cloud must contain.
        num_latents: target number of sampled points per cloud; the fps
            ratio is ``num_latents / num_inputs``.
        latent_dim: output feature size.
    """

    def __init__(self,
                 dim=512,
                 num_inputs=2048,
                 num_latents=512,
                 latent_dim=512):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_latents = num_latents
        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(dim, Attention(dim, dim, heads=1, dim_head=dim), context_dim=dim),
            PreNorm(dim, FeedForward(dim)),
        ])
        self.point_embed = PointEmbed(dim=dim)
        self.proj = nn.Linear(dim, latent_dim)

    def encode(self, pc):
        """Encode ``pc`` of shape ``(B, N, D)`` into ``(B, M, latent_dim)``.

        ``N`` must equal ``num_inputs``. ``D`` must be 3, because
        ``PointEmbed``'s basis has 3 rows. ``M`` is however many points fps
        returns per cloud — presumably ``num_latents``; TODO confirm for
        ratios that do not divide evenly. Non-contiguous ``pc`` is accepted.

        Raises:
            AssertionError: if ``N != num_inputs``.
        """
        B, N, D = pc.shape
        assert N == self.num_inputs
        # fps works on a flat point list; ``batch`` labels each point's cloud
        flattened = pc.reshape(B * N, D)
        batch = torch.repeat_interleave(torch.arange(B, device=pc.device), N)
        ratio = 1.0 * self.num_latents / self.num_inputs
        idx = fps(flattened, batch, ratio=ratio)
        # assumes fps returns the same number of indices for every cloud
        sampled_pc = flattened[idx].reshape(B, -1, D)

        sampled_pc_embeddings = self.point_embed(sampled_pc)
        pc_embeddings = self.point_embed(pc)
        cross_attn, cross_ff = self.cross_attend_blocks
        x = cross_attn(sampled_pc_embeddings, context=pc_embeddings, mask=None) + sampled_pc_embeddings
        x = cross_ff(x) + x
        return self.proj(x)