|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import sys |
|
import os |
|
from external.pointnet2.pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule |
|
from .utils import zero_module |
|
from .Positional_Embedding import PositionalEmbedding |
|
|
|
class Pointnet2Encoder(nn.Module):
    """Four-stage PointNet++ set-abstraction encoder.

    Chains four ``PointnetSAModuleVotes`` layers (``sa1`` .. ``sa4``) and
    records the coordinates and features returned by every stage, so that a
    decoder can later use them as skip connections.

    Parameters
    ----------
    input_feature_dim: int
        First entry of the ``mlp`` spec of ``sa1``; presumably the number of
        per-point feature channels that follow the xyz columns (0 for a
        coordinates-only cloud).
    npoints, radius, nsample: sequence of length 4
        Per-stage ``npoint`` / ``radius`` / ``nsample`` arguments forwarded
        to ``PointnetSAModuleVotes``; index ``i`` configures stage ``i + 1``.
        The sequences are only indexed, never modified.
    """

    # Stage attribute names, in the order the stages are applied.
    _STAGES = ('sa1', 'sa2', 'sa3', 'sa4')

    def __init__(self, input_feature_dim=0, npoints=(2048, 1024, 512, 256),
                 radius=(0.2, 0.4, 0.6, 1.2), nsample=(64, 32, 16, 8)):
        super().__init__()

        def build_stage(idx, mlp):
            # All stages share the same flags; only the sampling settings
            # and the MLP channel spec differ.
            return PointnetSAModuleVotes(
                npoint=npoints[idx],
                radius=radius[idx],
                nsample=nsample[idx],
                mlp=mlp,
                use_xyz=True,
                normalize_xyz=True
            )

        # The attributes are assigned explicitly (and in this order) because
        # their names are the prefixes of the parameters in the state dict.
        self.sa1 = build_stage(0, [input_feature_dim, 64, 64, 128])
        self.sa2 = build_stage(1, [128, 128, 128, 256])
        self.sa3 = build_stage(2, [256, 256, 256, 512])
        self.sa4 = build_stage(3, [512, 512, 512, 512])

    def _break_up_pc(self, pc):
        """Split a point cloud tensor into coordinates and features.

        Returns ``(xyz, features)`` where ``xyz`` is the first three entries
        of the last axis and ``features`` is the remainder with axes 1 and 2
        swapped, or ``None`` when the last axis holds no more than three
        entries.
        """
        xyz = pc[..., 0:3].contiguous()
        if pc.size(-1) > 3:
            features = pc[..., 3:].transpose(1, 2).contiguous()
        else:
            features = None
        return xyz, features

    def forward(self, pointcloud, end_points=None):
        """Run the four stages and collect their outputs.

        Parameters
        ----------
        pointcloud: torch.Tensor
            Tensor whose last axis is laid out as ``(x, y, z, features...)``;
            assumed to be ``(B, N, 3 + C)`` since the feature part has its
            axes 1 and 2 swapped -- TODO confirm against callers.
        end_points: dict, optional
            Dictionary that is filled in place and returned. A fresh dict is
            created only when ``None`` is passed, so an empty dict supplied
            by the caller is populated rather than replaced.

        Returns
        -------
        dict
            ``end_points`` with the keys ``'org_xyz'`` and, for each stage
            ``k`` in 1..4, ``'sa{k}_xyz'`` and ``'sa{k}_features'``.
        """
        if end_points is None:
            end_points = {}

        xyz, features = self._break_up_pc(pointcloud)
        end_points['org_xyz'] = xyz

        # Each stage consumes the coordinates/features of the previous one;
        # the third element of the stage output is not used here.
        for stage in self._STAGES:
            xyz, features, _ = getattr(self, stage)(xyz, features)
            end_points[stage + '_xyz'] = xyz
            end_points[stage + '_features'] = features

        return end_points
|
|
|
|
|
|
|
class PointUNet(nn.Module):
    r"""
    Conditional per-point prediction network with a PointNet++ encoder pair
    and a feature-propagation decoder.

    Two independent ``Pointnet2Encoder`` instances are used: one for the
    noisy point cloud and one for the conditioning point cloud. The decoder
    alternates two kinds of ``PointnetFPModule`` steps at every level:

    * ``fpK_cross`` combines the current noisy-branch features with the
      conditioning encoder's features of the same encoder level;
    * ``fpK`` moves the result to the next finer level of the noisy branch,
      using that level's encoder features as a skip connection (the last
      step, ``fp4``, has no skip features).

    A timestep embedding (``PositionalEmbedding(256)`` followed by two
    linear layers with SiLU) is added to the deepest noisy-branch features.
    The final ``output_layer`` applies ``LayerNorm(128)`` and a bias-free
    ``Linear(128, 3)`` wrapped in ``zero_module`` (imported from ``.utils``;
    presumably zero-initialises the weights -- confirm).

    The constructor takes no arguments; all channel widths are fixed and
    match the default ``Pointnet2Encoder`` configuration.
    """

    def __init__(self):
        super().__init__()

        # NOTE: creation order and attribute names are kept as-is; they
        # determine the state-dict keys and the order of random init.
        self.noisy_encoder = Pointnet2Encoder()
        self.cond_encoder = Pointnet2Encoder()

        # The first mlp entry of each FP module is written as the sum of the
        # channel counts of its two feature inputs.
        self.fp1_cross = PointnetFPModule(mlp=[512 + 512, 512, 512])
        self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512])

        self.fp2_cross = PointnetFPModule(mlp=[512 + 512, 512, 256])
        self.fp2 = PointnetFPModule(mlp=[256 + 256, 512, 256])

        self.fp3_cross = PointnetFPModule(mlp=[256 + 256, 256, 128])
        self.fp3 = PointnetFPModule(mlp=[128 + 128, 256, 128])

        self.fp4_cross = PointnetFPModule(mlp=[128 + 128, 128, 128])
        # fp4 receives no skip features (None), hence a single 128 input.
        self.fp4 = PointnetFPModule(mlp=[128, 128, 128])

        self.output_layer = nn.Sequential(
            nn.LayerNorm(128),
            zero_module(nn.Linear(in_features=128, out_features=3, bias=False))
        )

        # Timestep conditioning: 256-d embedding mapped to 512 channels,
        # the channel count of the ``sa4`` features it is added to.
        self.t_emb_layer = PositionalEmbedding(256)
        self.map_layer0 = nn.Linear(in_features=256, out_features=512)
        self.map_layer1 = nn.Linear(in_features=512, out_features=512)

    def forward(self, noise_points, t, cond_points):
        r"""
        Forward pass of the network.

        Parameters
        ----------
        noise_points: torch.Tensor
            Point cloud passed to ``noisy_encoder``; presumably (B, N, 3)
            since the encoders are built with ``input_feature_dim=0``.
        t: torch.Tensor
            Timestep input for ``PositionalEmbedding``; presumably shape
            (B,) -- TODO confirm against ``PositionalEmbedding``.
        cond_points: torch.Tensor
            Conditioning point cloud passed to ``cond_encoder``; presumably
            (B, M, 3).

        Returns
        ----------
        output_points: torch.Tensor
            Result of ``output_layer`` applied to the decoded features after
            swapping axes 1 and 2; presumably (B, N, 3), one 3-vector per
            input point of ``noise_points``.
        """
        # Timestep embedding -> 512 channels, with a trailing axis so that it
        # broadcasts over the point axis of the (B, C, K) feature maps.
        t_emb = self.t_emb_layer(t)
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None]

        noisy = self.noisy_encoder(noise_points)
        cond = self.cond_encoder(cond_points)

        # PointnetFPModule arguments are assumed to be
        # (target xyz, source xyz, target features, source features)
        # -- confirm against external.pointnet2.

        # Level 4: inject the timestep and mix in conditioning features.
        features = self.fp1_cross(
            noisy['sa4_xyz'], cond['sa4_xyz'],
            noisy['sa4_features'] + t_emb, cond['sa4_features'])
        # Level 4 -> 3 of the noisy branch, with sa3 skip features.
        features = self.fp1(
            noisy['sa3_xyz'], noisy['sa4_xyz'],
            noisy['sa3_features'], features)

        # Level 3: conditioning, then 3 -> 2.
        features = self.fp2_cross(
            noisy['sa3_xyz'], cond['sa3_xyz'],
            features, cond["sa3_features"])
        features = self.fp2(
            noisy['sa2_xyz'], noisy['sa3_xyz'],
            noisy['sa2_features'], features)

        # Level 2: conditioning, then 2 -> 1.
        features = self.fp3_cross(
            noisy['sa2_xyz'], cond['sa2_xyz'],
            features, cond['sa2_features'])
        features = self.fp3(
            noisy['sa1_xyz'], noisy['sa2_xyz'],
            noisy['sa1_features'], features)

        # Level 1: conditioning, then back to the original points
        # (no skip features are available at the input resolution).
        features = self.fp4_cross(
            noisy['sa1_xyz'], cond['sa1_xyz'],
            features, cond['sa1_features'])
        features = self.fp4(
            noisy['org_xyz'], noisy['sa1_xyz'], None, features)

        # Channels-last layout for LayerNorm / Linear.
        features = features.transpose(1, 2)

        output_points = self.output_layer(features)

        return output_points
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: push random clouds through the network and print the
    # output shape. Requires a CUDA device. Because this module uses
    # relative imports, it has to be started as part of its package
    # (``python -m ...``) rather than as a standalone script.
    net = PointUNet().cuda().float()
    net = net.eval()
    noise_points = torch.randn(16, 4096, 3).cuda().float()
    cond_points = torch.randn(16, 4096, 3).cuda().float()
    t = torch.randn(16).cuda().float()

    # PointUNet.forward takes (noise_points, t, cond_points); no gradients
    # are needed for a shape check.
    with torch.no_grad():
        out = net(noise_points, t, cond_points)
    print(out.shape)