LASA / models /modules /pointnet2_backbone.py
HaolinLiu's picture
first commit of codes and update readme.md
cc9780d
raw
history blame
7.82 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
import os
from external.pointnet2.pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule
from .utils import zero_module
from .Positional_Embedding import PositionalEmbedding
class Pointnet2Encoder(nn.Module):
    r"""
    Point cloud encoder made of four ``PointnetSAModuleVotes`` stages
    applied one after another.

    The MLP specs of the four stages are ``[input_feature_dim, 64, 64, 128]``,
    ``[128, 128, 128, 256]``, ``[256, 256, 256, 512]`` and
    ``[512, 512, 512, 512]``; every stage is built with ``use_xyz=True`` and
    ``normalize_xyz=True``.

    Parameters
    ----------
    input_feature_dim: int
        First entry of the first stage's MLP spec. Presumably the number of
        per-point feature channels besides xyz (0 for coordinates only) --
        TODO confirm against PointnetSAModuleVotes.
    npoints: sequence of 4 ints
        ``npoint`` argument of stages 1-4.
    radius: sequence of 4 floats
        ``radius`` argument of stages 1-4.
    nsample: sequence of 4 ints
        ``nsample`` argument of stages 1-4.
    """

    # Defaults are tuples rather than lists so they cannot be mutated and
    # shared between instances; they are only ever indexed.
    def __init__(self, input_feature_dim=0, npoints=(2048, 1024, 512, 256),
                 radius=(0.2, 0.4, 0.6, 1.2), nsample=(64, 32, 16, 8)):
        super().__init__()

        self.sa1 = PointnetSAModuleVotes(
            npoint=npoints[0],
            radius=radius[0],
            nsample=nsample[0],
            mlp=[input_feature_dim, 64, 64, 128],
            use_xyz=True,
            normalize_xyz=True
        )
        self.sa2 = PointnetSAModuleVotes(
            npoint=npoints[1],
            radius=radius[1],
            nsample=nsample[1],
            mlp=[128, 128, 128, 256],
            use_xyz=True,
            normalize_xyz=True
        )
        self.sa3 = PointnetSAModuleVotes(
            npoint=npoints[2],
            radius=radius[2],
            nsample=nsample[2],
            mlp=[256, 256, 256, 512],
            use_xyz=True,
            normalize_xyz=True
        )
        self.sa4 = PointnetSAModuleVotes(
            npoint=npoints[3],
            radius=radius[3],
            nsample=nsample[3],
            mlp=[512, 512, 512, 512],
            use_xyz=True,
            normalize_xyz=True
        )

    def _break_up_pc(self, pc):
        r"""
        Split a point cloud tensor into coordinates and extra features.

        Parameters
        ----------
        pc: torch.Tensor
            Tensor whose last dimension holds ``(x, y, z, features...)``.
            The transpose below needs at least 3 dimensions when extra
            feature channels are present.

        Returns
        ----------
        xyz: torch.Tensor
            Contiguous copy of the first three channels of the last dimension.
        features: torch.Tensor or None
            Remaining channels with dimensions 1 and 2 swapped, or ``None``
            when the last dimension has no more than three channels.
        """
        xyz = pc[..., 0:3].contiguous()
        if pc.size(-1) > 3:
            features = pc[..., 3:].transpose(1, 2).contiguous()
        else:
            features = None
        return xyz, features

    def forward(self, pointcloud, end_points=None):
        r"""
        Run the four set-abstraction stages and collect their outputs.

        Parameters
        ----------
        pointcloud: torch.Tensor
            Input points laid out as ``(x, y, z, features...)`` along the
            last dimension.
        end_points: dict, optional
            Dictionary to write the results into. A new one is created only
            when ``None`` is given, so a caller-supplied dictionary (even an
            empty one) is filled in place.

        Returns
        ----------
        end_points: dict
            Holds ``org_xyz`` plus ``saK_xyz`` and ``saK_features`` for
            K = 1..4, inserted in that order.
        """
        if end_points is None:
            end_points = {}

        xyz, features = self._break_up_pc(pointcloud)
        end_points['org_xyz'] = xyz

        # --------- 4 SET ABSTRACTION LAYERS ---------
        # Each stage consumes the previous stage's output. The third return
        # value is discarded (presumably the sampling indices -- TODO confirm
        # against PointnetSAModuleVotes).
        xyz, features, _ = self.sa1(xyz, features)
        end_points['sa1_xyz'] = xyz
        end_points['sa1_features'] = features

        xyz, features, _ = self.sa2(xyz, features)
        end_points['sa2_xyz'] = xyz
        end_points['sa2_features'] = features

        xyz, features, _ = self.sa3(xyz, features)
        end_points['sa3_xyz'] = xyz
        end_points['sa3_features'] = features

        xyz, features, _ = self.sa4(xyz, features)
        end_points['sa4_xyz'] = xyz
        end_points['sa4_features'] = features

        return end_points
class PointUNet(nn.Module):
    r"""
    Point-wise U-Net that combines a noisy point cloud, a timestep and a
    conditioning point cloud.

    Two separate ``Pointnet2Encoder`` instances encode the noisy and the
    conditioning points. The decoder alternates two kinds of
    ``PointnetFPModule``: a ``*_cross`` module that mixes in the conditioning
    features of the current level, followed by a module that moves to the
    next finer level of the noisy branch. A final LayerNorm + bias-free
    Linear head maps the 128 decoder channels to 3 output channels.

    The constructor takes no arguments; both encoders use their defaults.
    """

    def __init__(self):
        super().__init__()

        # NOTE: submodule names and construction order are kept stable so
        # that checkpoint keys and seeded initialisation stay the same.
        self.noisy_encoder = Pointnet2Encoder()
        self.cond_encoder = Pointnet2Encoder()

        # The first mlp entry is the summed channel count of the two feature
        # inputs each module receives in forward().
        self.fp1_cross = PointnetFPModule(mlp=[1024, 512, 512])
        self.fp1 = PointnetFPModule(mlp=[1024, 512, 512])

        self.fp2_cross = PointnetFPModule(mlp=[1024, 512, 256])
        self.fp2 = PointnetFPModule(mlp=[512, 512, 256])

        self.fp3_cross = PointnetFPModule(mlp=[512, 256, 128])
        self.fp3 = PointnetFPModule(mlp=[256, 256, 128])

        self.fp4_cross = PointnetFPModule(mlp=[256, 128, 128])
        self.fp4 = PointnetFPModule(mlp=[128, 128, 128])

        # zero_module presumably zero-initialises the Linear weights --
        # TODO confirm in .utils.
        self.output_layer = nn.Sequential(
            nn.LayerNorm(128),
            zero_module(nn.Linear(128, 3, bias=False)),
        )

        # Timestep embedding: 256 -> 512 -> 512, SiLU applied in forward().
        self.t_emb_layer = PositionalEmbedding(256)
        self.map_layer0 = nn.Linear(256, 512)
        self.map_layer1 = nn.Linear(512, 512)

    def forward(self, noise_points, t, cond_points):
        r"""
        Forward pass of the network

        Parameters
        ----------
        noise_points: torch.Tensor
            Noisy point cloud fed to ``noisy_encoder``; the module's own
            smoke test uses shape (B, N, 3).
        t: torch.Tensor
            Timestep values fed to the positional embedding; the smoke test
            uses shape (B,).
        cond_points: torch.Tensor
            Conditioning point cloud fed to ``cond_encoder``; the smoke test
            uses shape (B, N, 3).

        Returns
        ----------
        output_points: torch.Tensor
            Result of the output head applied to the decoder features after
            swapping dimensions 1 and 2; the last dimension has size 3.
        """
        # Timestep embedding, then add a trailing axis at dim 2 so it
        # broadcasts when summed with the coarsest noisy features.
        time_code = F.silu(self.map_layer0(self.t_emb_layer(t)))
        time_code = F.silu(self.map_layer1(time_code)).unsqueeze(2)

        noisy = self.noisy_encoder(noise_points)
        cond = self.cond_encoder(cond_points)

        # Noisy branch: coordinates and skip features per level (0 = input).
        n_xyz0 = noisy['org_xyz']
        n_xyz1, n_skip1 = noisy['sa1_xyz'], noisy['sa1_features']
        n_xyz2, n_skip2 = noisy['sa2_xyz'], noisy['sa2_features']
        n_xyz3, n_skip3 = noisy['sa3_xyz'], noisy['sa3_features']
        n_xyz4, n_skip4 = noisy['sa4_xyz'], noisy['sa4_features']

        # Conditioning branch: coordinates and features per level.
        c_xyz1, c_feat1 = cond['sa1_xyz'], cond['sa1_features']
        c_xyz2, c_feat2 = cond['sa2_xyz'], cond['sa2_features']
        c_xyz3, c_feat3 = cond['sa3_xyz'], cond['sa3_features']
        c_xyz4, c_feat4 = cond['sa4_xyz'], cond['sa4_features']

        # --------- 8 FEATURE PROPAGATION LAYERS (4 cross + 4 upsampling) ---------
        # Level 4 (coarsest): inject the timestep, mix in the condition,
        # then move to level 3.
        feats = self.fp1_cross(n_xyz4, c_xyz4, n_skip4 + time_code, c_feat4)
        feats = self.fp1(n_xyz3, n_xyz4, n_skip3, feats)

        # Level 3 -> level 2.
        feats = self.fp2_cross(n_xyz3, c_xyz3, feats, c_feat3)
        feats = self.fp2(n_xyz2, n_xyz3, n_skip2, feats)

        # Level 2 -> level 1.
        feats = self.fp3_cross(n_xyz2, c_xyz2, feats, c_feat2)
        feats = self.fp3(n_xyz1, n_xyz2, n_skip1, feats)

        # Level 1 -> input points; no skip features exist at the input level.
        feats = self.fp4_cross(n_xyz1, c_xyz1, feats, c_feat1)
        feats = self.fp4(n_xyz0, n_xyz1, None, feats)

        # Swap dims 1 and 2 so the 128 channels end up last for the
        # LayerNorm / Linear head.
        return self.output_layer(feats.transpose(1, 2))
if __name__ == '__main__':
    # Smoke test: push random inputs through the network on the GPU and
    # print the output shape. Requires CUDA.
    net = PointUNet().cuda().float()
    net = net.eval()
    noise_points = torch.randn(16, 4096, 3).cuda().float()
    cond_points = torch.randn(16, 4096, 3).cuda().float()
    t = torch.randn(16).cuda().float()
    # PointUNet.forward takes (noise_points, t, cond_points); the timestep
    # tensor must be passed, otherwise the call raises a TypeError.
    out = net(noise_points, t, cond_points)
    print(out.shape)