# ziqima's picture
# initial commit
# 4893ce0
# raw
# history blame
# 3.4 kB
import torch
from torch.autograd import Function
from pointops._C import knn_query_cuda, random_ball_query_cuda, ball_query_cuda
class KNNQuery(Function):
    """K-nearest-neighbour query backed by the ``knn_query_cuda`` kernel.

    Only ``forward`` is implemented; this class defines no ``backward``.
    """

    @staticmethod
    def forward(ctx, nsample, xyz, offset, new_xyz=None, new_offset=None):
        """Query ``nsample`` neighbours in ``xyz`` for each point of ``new_xyz``.

        Args:
            ctx: autograd context (not used here).
            nsample: number of neighbours returned per query point.
            xyz: (n, 3) reference coordinates; must be contiguous.
            offset: (b) batch offsets for ``xyz``; passed to the kernel as
                int32. NOTE(review): presumably cumulative end indices of
                each batch element -- confirm against the kernel.
            new_xyz: (m, 3) query coordinates; must be contiguous.
            new_offset: (b) batch offsets for ``new_xyz``.

            If either ``new_xyz`` or ``new_offset`` is ``None``, *both* are
            replaced by ``xyz`` / ``offset`` (the set queries itself).

        Returns:
            Tuple ``(idx, dist)``:
            idx: (m, nsample) int32 neighbour indices, -1 is placeholder.
            dist: (m, nsample) float32, the element-wise square root of the
                kernel's ``dist2`` output (i.e. not the squared value).

        Raises:
            AssertionError: if ``xyz`` or ``new_xyz`` is not contiguous.
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert (
            xyz.is_contiguous() and new_xyz.is_contiguous()
        ), "xyz and new_xyz must be contiguous"
        m = new_xyz.shape[0]
        # Allocate the output buffers on the device that holds the inputs.
        # The legacy torch.cuda.*Tensor constructors always use the *current*
        # CUDA device, which may differ from xyz.device.
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=xyz.device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=xyz.device)
        # The kernel fills idx and dist2 in place.
        knn_query_cuda(
            m, nsample, xyz, new_xyz, offset.int(), new_offset.int(), idx, dist2
        )
        return idx, torch.sqrt(dist2)
class RandomBallQuery(Function):
    """Random Ball Query.

    Find nearby points in spherical space. A random permutation of the point
    indices of every batch element is generated and handed to the
    ``random_ball_query_cuda`` kernel.
    NOTE(review): presumably the kernel visits candidates in that order --
    confirm against the CUDA source.

    Only ``forward`` is implemented; this class defines no ``backward``.
    """

    @staticmethod
    def forward(
        ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None
    ):
        """Run the random ball query.

        Args:
            ctx: autograd context (not used here).
            nsample: number of neighbours returned per query point.
            max_radius: outer radius; must be greater than ``min_radius``.
            min_radius: inner radius.
            xyz: (n, 3) reference coordinates; must be contiguous.
            offset: (b) integer batch offsets for ``xyz``; entry ``k`` is the
                exclusive end index of batch element ``k`` (the first element
                starts at 0).
            new_xyz: (m, 3) query coordinates; must be contiguous.
            new_offset: (b) batch offsets for ``new_xyz``.

            If either ``new_xyz`` or ``new_offset`` is ``None``, *both* are
            replaced by ``xyz`` / ``offset``.

        Returns:
            Tuple ``(idx, dist)``:
            idx: (m, nsample) int32 neighbour indices.
            dist: (m, nsample) float32, the element-wise square root of the
                kernel's ``dist2`` output (i.e. not the squared value).

        Raises:
            AssertionError: if an input is not contiguous or
                ``min_radius >= max_radius``.
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert (
            xyz.is_contiguous() and new_xyz.is_contiguous()
        ), "xyz and new_xyz must be contiguous"
        assert min_radius < max_radius, "min_radius must be smaller than max_radius"
        m = new_xyz.shape[0]

        # Pull the offsets to the host once instead of converting a 0-dim
        # tensor to an int on every loop iteration.
        bounds = [0] + offset.tolist()
        # One shuffled index range [start, end) per batch element, concatenated
        # in batch order.
        order = torch.cat(
            [
                torch.randperm(end - start, dtype=torch.int32, device=offset.device)
                + start
                for start, end in zip(bounds[:-1], bounds[1:])
            ],
            dim=0,
        )

        # Allocate the output buffers on the device that holds the inputs.
        # The legacy torch.cuda.*Tensor constructors always use the *current*
        # CUDA device, which may differ from xyz.device.
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=xyz.device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=xyz.device)
        # Note the kernel takes (min_radius, max_radius), the reverse of this
        # method's parameter order. It fills idx and dist2 in place.
        random_ball_query_cuda(
            m,
            nsample,
            min_radius,
            max_radius,
            order,
            xyz,
            new_xyz,
            offset.int(),
            new_offset.int(),
            idx,
            dist2,
        )
        return idx, torch.sqrt(dist2)
class BallQuery(Function):
    """Ball Query.

    Find nearby points in spherical space via the ``ball_query_cuda`` kernel.

    Only ``forward`` is implemented; this class defines no ``backward``.
    """

    @staticmethod
    def forward(
        ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None
    ):
        """Run the ball query.

        Args:
            ctx: autograd context (not used here).
            nsample: number of neighbours returned per query point.
            max_radius: outer radius; must be greater than ``min_radius``.
            min_radius: inner radius.
            xyz: (n, 3) reference coordinates; must be contiguous.
            offset: (b) batch offsets for ``xyz``; passed to the kernel as
                int32. NOTE(review): presumably cumulative end indices of
                each batch element -- confirm against the kernel.
            new_xyz: (m, 3) query coordinates; must be contiguous.
            new_offset: (b) batch offsets for ``new_xyz``.

            If either ``new_xyz`` or ``new_offset`` is ``None``, *both* are
            replaced by ``xyz`` / ``offset``.

        Returns:
            Tuple ``(idx, dist)``:
            idx: (m, nsample) int32 neighbour indices.
            dist: (m, nsample) float32, the element-wise square root of the
                kernel's ``dist2`` output (i.e. not the squared value).

        Raises:
            AssertionError: if an input is not contiguous or
                ``min_radius >= max_radius``.
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert (
            xyz.is_contiguous() and new_xyz.is_contiguous()
        ), "xyz and new_xyz must be contiguous"
        assert min_radius < max_radius, "min_radius must be smaller than max_radius"
        m = new_xyz.shape[0]
        # Allocate the output buffers on the device that holds the inputs.
        # The legacy torch.cuda.*Tensor constructors always use the *current*
        # CUDA device, which may differ from xyz.device.
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=xyz.device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=xyz.device)
        # Note the kernel takes (min_radius, max_radius), the reverse of this
        # method's parameter order. It fills idx and dist2 in place.
        ball_query_cuda(
            m,
            nsample,
            min_radius,
            max_radius,
            xyz,
            new_xyz,
            offset.int(),
            new_offset.int(),
            idx,
            dist2,
        )
        return idx, torch.sqrt(dist2)
# Functional entry points, listed in class-definition order. Each name is the
# ``apply`` classmethod of the corresponding autograd Function, so callers use
# e.g. ``knn_query(nsample, xyz, offset)`` without passing ``ctx``.
knn_query = KNNQuery.apply
random_ball_query = RandomBallQuery.apply
ball_query = BallQuery.apply