from functools import wraps

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from timm.models.layers import DropPath
from torch_cluster import fps
import numpy as np


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class PositionalEmbedding(torch.nn.Module):
    """Sinusoidal embedding for scalar inputs (e.g. diffusion timesteps / noise levels)."""

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        # x: (B,) scalars; returns (B, num_channels) with cosine and sine halves.
        freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device)
        freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        x = x.ger(freqs.to(x.dtype))  # outer product: (B, num_channels // 2)
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x


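# Usage sketch (illustrative only; the batch size and channel count below are assumptions,
# not part of the original pipeline): map a batch of scalar timesteps to sinusoidal features.
def _example_positional_embedding():
    emb = PositionalEmbedding(num_channels=256)
    t = torch.rand(8)          # one scalar noise level / timestep per sample
    features = emb(t)          # (8, 256): cosine half followed by sine half
    return features.shape

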
class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads

        # Default to self-attention when no context dimension is given.
        if context_dim is None:
            context_dim = query_dim

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        # `mask` is accepted for API compatibility but not used here.
        h = self.heads

        q = self.to_q(x)

        # Self-attention if no context is provided.
        if context is None:
            context = x

        k = self.to_k(context)
        v = self.to_v(context)

        # Split heads: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        # Scaled dot-product attention.
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

        attn = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


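# Usage sketch (illustrative only; all shapes and dims are assumptions): CrossAttention acts
# as self-attention when `context` is omitted and as cross-attention otherwise.
def _example_cross_attention():
    attn = CrossAttention(query_dim=128, context_dim=64, heads=4, dim_head=32)
    x = torch.randn(2, 16, 128)                   # queries: (batch, tokens, query_dim)
    ctx = torch.randn(2, 77, 64)                  # context: (batch, context tokens, context_dim)
    self_out = CrossAttention(query_dim=128)(x)   # self-attention, (2, 16, 128)
    cross_out = attn(x, context=ctx)              # cross-attention, (2, 16, 128)
    return self_out.shape, cross_out.shape

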
class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


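# Usage sketch (shapes are assumptions): the feed-forward block preserves the token shape;
# `glu=True` swaps the first projection for a GEGLU gate.
def _example_feedforward():
    x = torch.randn(2, 16, 128)
    plain = FeedForward(128)(x)            # (2, 16, 128): Linear -> GELU -> Linear
    gated = FeedForward(128, glu=True)(x)  # (2, 16, 128): GEGLU gate -> Linear
    return plain.shape, gated.shape

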
class AdaLayerNorm(nn.Module):
    def __init__(self, n_embd):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(n_embd, n_embd * 2)
        self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)

    def forward(self, x, timestep):
        # Predict a per-channel scale and shift from the conditioning signal.
        emb = self.linear(self.silu(timestep))
        scale, shift = torch.chunk(emb, 2, dim=2)
        x = self.layernorm(x) * (1 + scale) + shift
        return x


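# Usage sketch (shapes are assumptions): the conditioning signal is broadcast over the token
# dimension, so a (B, 1, C) timestep embedding modulates every token of a (B, N, C) input.
def _example_ada_layer_norm():
    norm = AdaLayerNorm(128)
    x = torch.randn(2, 16, 128)        # token features
    t_emb = torch.randn(2, 1, 128)     # per-sample conditioning, broadcast over tokens
    return norm(x, t_emb).shape        # (2, 16, 128)

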
class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        # Self-attention, cross-attention and feed-forward, each preceded by an
        # adaptive LayerNorm conditioned on the timestep embedding.
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = AdaLayerNorm(dim)
        self.norm2 = AdaLayerNorm(dim)
        self.norm3 = AdaLayerNorm(dim)
        self.checkpoint = checkpoint

        # LayerScale and DropPath are disabled here, so the branches below fall back to identity.
        init_values = 0
        drop_path = 0.0

        self.ls1 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.ls2 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.ls3 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path3 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, t, context=None):
        x = self.drop_path1(self.ls1(self.attn1(self.norm1(x, t)))) + x
        x = self.drop_path2(self.ls2(self.attn2(self.norm2(x, t), context=context))) + x
        x = self.drop_path3(self.ls3(self.ff(self.norm3(x, t)))) + x
        return x


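# Usage sketch (dims are assumptions): one block applies timestep-conditioned self-attention,
# cross-attention to an optional context, and a gated feed-forward, each with a residual.
def _example_basic_transformer_block():
    block = BasicTransformerBlock(dim=128, n_heads=4, d_head=32, context_dim=64)
    x = torch.randn(2, 16, 128)
    t_emb = torch.randn(2, 1, 128)     # same conditioning that AdaLayerNorm consumes
    cond = torch.randn(2, 77, 64)
    return block(x, t_emb, context=cond).shape   # (2, 16, 128)

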
class LatentArrayTransformer(nn.Module):
    """
    Transformer operating on a latent array (a set of latent tokens).
    The input of shape (b, n, in_channels) is projected to the inner
    dimension, processed by a stack of timestep-conditioned transformer
    blocks (with optional cross-attention context), and finally projected
    to the output dimension.
    """

    def __init__(self, in_channels, t_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
                 block=BasicTransformerBlock):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head

        self.t_channels = t_channels

        self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)

        self.transformer_blocks = nn.ModuleList(
            [block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
             for _ in range(depth)]
        )

        self.norm = nn.LayerNorm(inner_dim)

        # The output projection is zero-initialized, so the module initially outputs zeros.
        if out_channels is None:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False))
        else:
            self.num_cls = out_channels
            self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False))

        self.context_dim = context_dim

        self.map_noise = PositionalEmbedding(t_channels)

        self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim)
        self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim)

    def forward(self, x, t, cond, class_emb):
        # Embed the noise level / timestep and map it to the inner dimension.
        t_emb = self.map_noise(t)[:, None]
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))

        x = self.proj_in(x)

        # Each block is conditioned on timestep + class embedding and cross-attends to `cond`.
        for block in self.transformer_blocks:
            x = block(x, t_emb + class_emb[:, None, :], context=cond)

        x = self.norm(x)

        x = self.proj_out(x)
        return x


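# Usage sketch (every dimension below is an assumption chosen for illustration): denoise a set
# of latent tokens conditioned on a timestep, a cross-attention context and a class embedding.
def _example_latent_array_transformer():
    model = LatentArrayTransformer(in_channels=32, t_channels=256, n_heads=4, d_head=32,
                                   depth=2, context_dim=64)
    x = torch.randn(2, 512, 32)        # noisy latent array: (batch, tokens, in_channels)
    t = torch.rand(2)                  # one noise level per sample
    cond = torch.randn(2, 77, 64)      # cross-attention context
    class_emb = torch.randn(2, 128)    # must match inner_dim = n_heads * d_head
    return model(x, t, cond, class_emb).shape   # (2, 512, 32)

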
class PointTransformer(nn.Module):
    """
    Same structure as LatentArrayTransformer, but conditioned on the timestep
    embedding only (no class embedding): project the input tokens, apply a
    stack of timestep-conditioned transformer blocks with optional
    cross-attention context, and project back out.
    """

    def __init__(self, in_channels, t_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
                 block=BasicTransformerBlock):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head

        self.t_channels = t_channels

        self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)

        self.transformer_blocks = nn.ModuleList(
            [block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
             for _ in range(depth)]
        )

        self.norm = nn.LayerNorm(inner_dim)

        if out_channels is None:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False))
        else:
            self.num_cls = out_channels
            self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False))

        self.context_dim = context_dim

        self.map_noise = PositionalEmbedding(t_channels)

        self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim)
        self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim)

    def forward(self, x, t, cond=None):
        t_emb = self.map_noise(t)[:, None]
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))

        x = self.proj_in(x)

        for block in self.transformer_blocks:
            x = block(x, t_emb, context=cond)

        x = self.norm(x)

        x = self.proj_out(x)
        return x


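# Usage sketch (dims are assumptions): PointTransformer follows the same pattern as
# LatentArrayTransformer but takes only an optional cross-attention context.
def _example_point_transformer():
    model = PointTransformer(in_channels=3, t_channels=256, n_heads=4, d_head=32, depth=1)
    x = torch.randn(2, 1024, 3)
    t = torch.rand(2)
    return model(x, t).shape    # (2, 1024, 3)

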
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cache_fn(f):
    # Memoize a function: the first call runs `f` and caches the result; later calls return
    # the cached value unless `_cache=False` is passed.
    cache = None

    @wraps(f)
    def cached_fn(*args, _cache=True, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache
    return cached_fn


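# Usage sketch (the wrapped function is hypothetical): cache_fn memoizes the first result;
# subsequent calls return it unless `_cache=False` forces re-evaluation.
def _example_cache_fn():
    build = cache_fn(lambda: nn.Linear(8, 8))
    first = build()
    second = build()             # same object as `first`
    fresh = build(_cache=False)  # bypasses the cache, builds a new module
    return first is second, first is fresh   # (True, False)

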
class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)

        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context=normed_context)

        return self.fn(x, **kwargs)


class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, drop_path_rate=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            # Mask out invalid context positions with a large negative value before the softmax.
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.drop_path(self.to_out(out))


class PointEmbed(nn.Module):
    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()

        assert hidden_dim % 6 == 0

        # Fourier-feature basis: hidden_dim // 6 power-of-two frequencies (scaled by pi) per
        # x/y/z axis; each projection contributes a sin and a cos, giving hidden_dim features.
        self.embedding_dim = hidden_dim
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6), e]),
        ])
        self.register_buffer('basis', e)

        self.mlp = nn.Linear(self.embedding_dim + 3, dim)

    @staticmethod
    def embed(input, basis):
        projections = torch.einsum(
            'bnd,de->bne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings

    def forward(self, input):
        # Concatenate the Fourier features with the raw coordinates, then project to `dim`.
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2))
        return embed


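# Usage sketch (shapes are assumptions): embed xyz coordinates with Fourier features plus the
# raw coordinates, projected to `dim` channels per point.
def _example_point_embed():
    embed = PointEmbed(hidden_dim=48, dim=128)
    pc = torch.randn(2, 2048, 3)       # (batch, points, xyz)
    return embed(pc).shape             # (2, 2048, 128)

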
class PointEncoder(nn.Module):
    def __init__(self,
                 dim=512,
                 num_inputs=2048,
                 num_latents=512,
                 latent_dim=512):
        super().__init__()

        self.num_inputs = num_inputs
        self.num_latents = num_latents

        # Single cross-attention block: the FPS-subsampled points query the full point cloud.
        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(dim, Attention(dim, dim, heads=1, dim_head=dim), context_dim=dim),
            PreNorm(dim, FeedForward(dim))
        ])

        self.point_embed = PointEmbed(dim=dim)
        self.proj = nn.Linear(dim, latent_dim)

    def encode(self, pc):
        B, N, D = pc.shape
        assert N == self.num_inputs

        # Farthest point sampling expects flattened points with a batch index per point.
        flattened = pc.view(B * N, D)

        batch = torch.arange(B).to(pc.device)
        batch = torch.repeat_interleave(batch, N)

        pos = flattened

        ratio = 1.0 * self.num_latents / self.num_inputs

        idx = fps(pos, batch, ratio=ratio)

        sampled_pc = pos[idx]
        sampled_pc = sampled_pc.view(B, -1, 3)

        sampled_pc_embeddings = self.point_embed(sampled_pc)

        pc_embeddings = self.point_embed(pc)

        # The sampled points cross-attend to the full point cloud, followed by a feed-forward
        # layer (both with residual connections), then project to the latent dimension.
        cross_attn, cross_ff = self.cross_attend_blocks

        x = cross_attn(sampled_pc_embeddings, context=pc_embeddings, mask=None) + sampled_pc_embeddings
        x = cross_ff(x) + x

        return self.proj(x)
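

# Usage sketch (illustrative; the sizes mirror the defaults but are assumptions, and running
# it requires a working torch_cluster build for `fps`): encode a 2048-point cloud into
# `num_latents` latent vectors.
def _example_point_encoder():
    encoder = PointEncoder(dim=512, num_inputs=2048, num_latents=512, latent_dim=64)
    pc = torch.randn(2, 2048, 3)           # (batch, points, xyz)
    latents = encoder.encode(pc)           # (2, 512, 64): one latent per FPS-sampled point
    return latents.shape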