import numpy as np
import torch
import torch.nn as nn


class PointEmbed(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of 3-D point coordinates.

    Each of the three coordinates is multiplied by ``hidden_dim // 6``
    frequencies ``pi * 2**k`` (``k = 0, 1, ...``); the sine and cosine of
    every product are concatenated, giving ``hidden_dim`` features per point.

    Args:
        hidden_dim: Size of the produced embedding. Must be a multiple of 6
            (3 coordinates x {sin, cos} x ``hidden_dim // 6`` frequencies).
    """

    def __init__(self, hidden_dim: int = 48) -> None:
        super().__init__()

        assert hidden_dim % 6 == 0, (
            f"hidden_dim must be a multiple of 6, got {hidden_dim}"
        )

        self.embedding_dim = hidden_dim

        num_freqs = self.embedding_dim // 6
        # Frequencies pi * 2**k for k = 0 .. num_freqs - 1.
        freqs = torch.pow(2, torch.arange(num_freqs)).float() * np.pi
        pad = torch.zeros(num_freqs)

        # Block layout: row i carries the frequencies in column block i and
        # zeros elsewhere, so coordinate i only contributes to its own block.
        basis = torch.stack([
            torch.cat([freqs, pad, pad]),
            torch.cat([pad, freqs, pad]),
            torch.cat([pad, pad, freqs]),
        ])
        # Shape: 3 x (hidden_dim // 2). Registered as a buffer so it follows
        # the module across devices/dtypes without being a trainable parameter.
        self.register_buffer('basis', basis)

    @staticmethod
    def embed(input: torch.Tensor, basis: torch.Tensor) -> torch.Tensor:
        """Project ``input`` onto ``basis`` and return ``[sin, cos]`` features.

        Args:
            input: Tensor of shape ``(..., D)``; any number of leading
                dimensions is accepted (e.g. ``B x N x 3``).
            basis: Tensor of shape ``(D, E)``.

        Returns:
            Tensor of shape ``(..., 2 * E)``: the sines of the projections
            followed by their cosines along the last dimension.
        """
        # The ellipsis keeps every leading dimension intact, so rank-3 input
        # behaves exactly like the former 'bnd,de->bne' contraction.
        projections = torch.einsum('...d,de->...e', input, basis)
        return torch.cat([projections.sin(), projections.cos()], dim=-1)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Embed points.

        Args:
            input: Point coordinates of shape ``(..., 3)``, typically
                ``B x N x 3``.

        Returns:
            Tensor of shape ``(..., embedding_dim)``.
        """
        return self.embed(input, self.basis)