Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from torch.autograd import Function | |
from pointops._C import knn_query_cuda, random_ball_query_cuda, ball_query_cuda | |
class KNNQuery(Function):
    """K-nearest-neighbour query over batched point clouds.

    Thin autograd ``Function`` wrapper around the ``knn_query_cuda`` kernel.
    No ``backward`` is defined, so the outputs are not differentiable.
    """

    @staticmethod
    def forward(ctx, nsample, xyz, offset, new_xyz=None, new_offset=None):
        """Look up ``nsample`` neighbours in ``xyz`` for every query point.

        Args:
            ctx: autograd context (unused).
            nsample: number of neighbours returned per query point.
            xyz: (n, 3) contiguous reference coordinates.
            offset: (b) batch offsets belonging to ``xyz``
                (presumably the cumulative end index of each batch -- confirm
                against the kernel).
            new_xyz: (m, 3) contiguous query coordinates. Defaults to ``xyz``.
            new_offset: (b) batch offsets belonging to ``new_xyz``. Defaults
                to ``offset``.

        Returns:
            Tuple ``(idx, dist)``:
            idx: (m, nsample) int32 neighbour indices; -1 is a placeholder.
            dist: (m, nsample) float32, the element-wise square root of the
                ``dist2`` buffer filled by the kernel.

        Note:
            If either ``new_xyz`` or ``new_offset`` is missing, *both* are
            replaced by ``xyz`` / ``offset`` (self-query).
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert xyz.is_contiguous() and new_xyz.is_contiguous()

        m = new_xyz.shape[0]
        # Allocate the output buffers on the same device as the query points
        # instead of whatever the "current" CUDA device happens to be.
        device = new_xyz.device
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=device)

        # The kernel writes its results into ``idx`` and ``dist2`` in place.
        knn_query_cuda(
            m, nsample, xyz, new_xyz, offset.int(), new_offset.int(), idx, dist2
        )
        return idx, torch.sqrt(dist2)
class RandomBallQuery(Function):
    """Random Ball Query.

    Find nearby points in spherical space, visiting the candidate points of
    each batch in a random order.

    Thin autograd ``Function`` wrapper around the ``random_ball_query_cuda``
    kernel. No ``backward`` is defined, so the outputs are not differentiable.
    """

    @staticmethod
    def forward(
        ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None
    ):
        """Run the random ball query.

        Args:
            ctx: autograd context (unused).
            nsample: number of neighbours returned per query point.
            max_radius: outer radius of the search shell.
            min_radius: inner radius of the search shell; must be strictly
                smaller than ``max_radius``.
            xyz: (n, 3) contiguous reference coordinates.
            offset: (b) cumulative end index of each batch inside ``xyz``
                (batch ``k`` spans ``offset[k - 1]:offset[k]``, batch 0
                starts at 0).
            new_xyz: (m, 3) contiguous query coordinates. Defaults to ``xyz``.
            new_offset: (b) batch offsets belonging to ``new_xyz``. Defaults
                to ``offset``.

        Returns:
            Tuple ``(idx, dist)``:
            idx: (m, nsample) int32 neighbour indices.
            dist: (m, nsample) float32, the element-wise square root of the
                ``dist2`` buffer filled by the kernel.

        Note:
            If either ``new_xyz`` or ``new_offset`` is missing, *both* are
            replaced by ``xyz`` / ``offset`` (self-query).
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert xyz.is_contiguous() and new_xyz.is_contiguous()
        assert min_radius < max_radius

        m = new_xyz.shape[0]

        # Build one independent random permutation per batch, shifted so that
        # it indexes into that batch's slice of ``xyz``. Pulling the offsets to
        # the host once gives plain ints for ``randperm`` and avoids indexing
        # the offset tensor element by element.
        bounds = [0] + offset.tolist()
        order = torch.cat(
            [
                torch.randperm(end - start, dtype=torch.int32, device=offset.device)
                + start
                for start, end in zip(bounds[:-1], bounds[1:])
            ],
            dim=0,
        )

        # Allocate the output buffers on the same device as the query points
        # instead of whatever the "current" CUDA device happens to be.
        device = new_xyz.device
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=device)

        # The kernel writes its results into ``idx`` and ``dist2`` in place.
        random_ball_query_cuda(
            m,
            nsample,
            min_radius,
            max_radius,
            order,
            xyz,
            new_xyz,
            offset.int(),
            new_offset.int(),
            idx,
            dist2,
        )
        return idx, torch.sqrt(dist2)
class BallQuery(Function):
    """Ball Query.

    Find nearby points in spherical space.

    Thin autograd ``Function`` wrapper around the ``ball_query_cuda`` kernel.
    No ``backward`` is defined, so the outputs are not differentiable.
    """

    @staticmethod
    def forward(
        ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None
    ):
        """Run the ball query.

        Args:
            ctx: autograd context (unused).
            nsample: number of neighbours returned per query point.
            max_radius: outer radius of the search shell.
            min_radius: inner radius of the search shell; must be strictly
                smaller than ``max_radius``.
            xyz: (n, 3) contiguous reference coordinates.
            offset: (b) batch offsets belonging to ``xyz``
                (presumably the cumulative end index of each batch -- confirm
                against the kernel).
            new_xyz: (m, 3) contiguous query coordinates. Defaults to ``xyz``.
            new_offset: (b) batch offsets belonging to ``new_xyz``. Defaults
                to ``offset``.

        Returns:
            Tuple ``(idx, dist)``:
            idx: (m, nsample) int32 neighbour indices.
            dist: (m, nsample) float32, the element-wise square root of the
                ``dist2`` buffer filled by the kernel.

        Note:
            If either ``new_xyz`` or ``new_offset`` is missing, *both* are
            replaced by ``xyz`` / ``offset`` (self-query).
        """
        if new_xyz is None or new_offset is None:
            new_xyz = xyz
            new_offset = offset
        assert xyz.is_contiguous() and new_xyz.is_contiguous()
        assert min_radius < max_radius

        m = new_xyz.shape[0]
        # Allocate the output buffers on the same device as the query points
        # instead of whatever the "current" CUDA device happens to be.
        device = new_xyz.device
        idx = torch.zeros((m, nsample), dtype=torch.int32, device=device)
        dist2 = torch.zeros((m, nsample), dtype=torch.float32, device=device)

        # The kernel writes its results into ``idx`` and ``dist2`` in place.
        ball_query_cuda(
            m,
            nsample,
            min_radius,
            max_radius,
            xyz,
            new_xyz,
            offset.int(),
            new_offset.int(),
            idx,
            dist2,
        )
        return idx, torch.sqrt(dist2)
# Functional entry points: call these instead of instantiating the Function
# subclasses directly.
knn_query = KNNQuery.apply
random_ball_query = RandomBallQuery.apply
ball_query = BallQuery.apply