Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,493 Bytes
4893ce0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import torch
from pointops import knn_query, ball_query, grouping
def knn_query_and_group(
    feat,
    xyz,
    offset=None,
    new_xyz=None,
    new_offset=None,
    idx=None,
    nsample=None,
    with_xyz=False,
):
    """Gather ``feat`` for the k nearest neighbours of each query point.

    If ``idx`` is given it is reused unchanged. Otherwise ``nsample``
    neighbours are looked up with ``pointops.knn_query`` (``nsample`` is then
    required).

    Returns:
        ``(grouped, idx)``: the result of ``pointops.grouping`` and the
        neighbour indices that were used for it.
    """
    must_search = idx is None
    if must_search:
        assert nsample is not None
        neighbour_idx, _ = knn_query(nsample, xyz, offset, new_xyz, new_offset)
    else:
        neighbour_idx = idx
    grouped = grouping(neighbour_idx, feat, xyz, new_xyz, with_xyz)
    return grouped, neighbour_idx
def ball_query_and_group(
    feat,
    xyz,
    offset=None,
    new_xyz=None,
    new_offset=None,
    idx=None,
    max_radio=None,
    min_radio=0,
    nsample=None,
    with_xyz=False,
):
    """Gather ``feat`` for neighbours found by a radius (ball) query.

    If ``idx`` is given it is reused unchanged. Otherwise ``pointops.ball_query``
    is called with ``nsample``, ``max_radio``, ``min_radio`` and ``offset``,
    all of which must then be non-None (``max_radio``/``min_radio`` are the
    upper/lower radius bounds).

    Returns:
        ``(grouped, idx)``: the result of ``pointops.grouping`` and the
        neighbour indices that were used for it.
    """
    if idx is not None:
        neighbour_idx = idx
    else:
        required = (nsample, offset, max_radio, min_radio)
        assert all(value is not None for value in required)
        neighbour_idx, _ = ball_query(
            nsample, max_radio, min_radio, xyz, offset, new_xyz, new_offset
        )
    grouped = grouping(neighbour_idx, feat, xyz, new_xyz, with_xyz)
    return grouped, neighbour_idx
def _dilated_knn_query(nsample, xyz, offset, new_xyz, new_offset, dilation):
    """Return ``(m, nsample)`` neighbour indices chosen from a dilated kNN query.

    One kNN query of size ``1 + (nsample - 1) * (dilation + 1)`` is run, and
    every ``(dilation + 1)``-th column is kept. A batch with fewer points than
    that query size uses a smaller ("soft") step instead, so its last picked
    column is roughly ``num_points - 1``.
    """
    num_samples_total = 1 + (nsample - 1) * (dilation + 1)
    # num points in a batch might < num_samples_total => [n1, n2, ..., nk, ns, ns, ns, ...]
    idx_no_dilation, _ = knn_query(
        num_samples_total, xyz, offset, new_xyz, new_offset
    )  # (m, num_samples_total)
    batch_end = offset.tolist()
    batch_start = [0] + batch_end[:-1]
    new_batch_end = new_offset.tolist()
    new_batch_start = [0] + new_batch_end[:-1]
    picked = []
    for b in range(offset.shape[0]):
        num_points = batch_end[b] - batch_start[b]
        if num_points < num_samples_total:
            # max(..., 1) avoids a ZeroDivisionError when nsample == 1. Only
            # column 0 is used in that case, so the step value does not matter.
            soft_dilation = (num_points - 1) / max(nsample - 1, 1) - 1
        else:
            soft_dilation = dilation
        columns = [int((soft_dilation + 1) * k) for k in range(nsample)]
        picked.append(
            idx_no_dilation[new_batch_start[b] : new_batch_end[b], columns]
        )
    return torch.cat(picked, dim=0)


def query_and_group(
    nsample,
    xyz,
    new_xyz,
    feat,
    idx,
    offset,
    new_offset,
    dilation=0,
    with_feat=True,
    with_xyz=True,
):
    """
    Gather neighbour features (optionally with relative coordinates) per query point.

    input: xyz: (n, 3), new_xyz: (m, 3) or None (defaults to xyz),
           feat: (n, c), idx: (m, nsample) or None (computed by a dilated kNN
           query, which then needs offset/new_offset), offset: (b),
           new_offset: (b) or None (defaults to offset when new_xyz is None)
    output: new_feat: (m, nsample, 3+c) if with_xyz else (m, nsample, c),
            grouped_idx: (m, nsample).
            When with_feat is False, only idx is returned.
    """
    # Resolve the default before validating, so that new_xyz=None is
    # actually supported.
    if new_xyz is None:
        new_xyz = xyz
        if new_offset is None:
            new_offset = offset
    assert xyz.is_contiguous() and new_xyz.is_contiguous()
    if with_feat:
        # feat is only read when features are gathered.
        assert feat.is_contiguous()
    if idx is None:
        idx = _dilated_knn_query(
            nsample, xyz, offset, new_xyz, new_offset, dilation
        )
    if not with_feat:
        return idx
    m, c = new_xyz.shape[0], feat.shape[1]
    flat_idx = idx.reshape(-1).long()
    grouped_xyz = xyz[flat_idx, :].view(m, nsample, 3)  # (m, nsample, 3)
    # Coordinates relative to each query point; the indexing result is a copy,
    # so the in-place subtraction does not touch xyz.
    grouped_xyz -= new_xyz.unsqueeze(1)
    grouped_feat = feat[flat_idx, :].view(m, nsample, c)  # (m, nsample, c)
    if with_xyz:
        return torch.cat((grouped_xyz, grouped_feat), -1), idx  # (m, nsample, 3+c)
    return grouped_feat, idx
def offset2batch(offset):
    """Expand cumulative batch end offsets into a per-point batch id tensor.

    ``offset[b]`` is the exclusive end index of batch ``b``. The result is a
    ``torch.long`` tensor on ``offset.device`` that holds ``b`` repeated
    ``offset[b] - offset[b - 1]`` times (``offset[0]`` times for ``b == 0``).
    An empty ``offset`` yields an empty tensor.
    """
    # Points per batch. Prepending 0 makes the first count equal to offset[0].
    counts = torch.diff(offset, prepend=offset.new_zeros(1)).long()
    batch_ids = torch.arange(counts.numel(), device=offset.device, dtype=torch.long)
    return batch_ids.repeat_interleave(counts)
def batch2offset(batch):
    """Convert per-point batch ids into cumulative batch end offsets (int32).

    ``bincount`` counts the points carrying each batch id, and the running
    sum of those counts is each batch's exclusive end index.
    """
    points_per_batch = batch.bincount()
    return points_per_batch.cumsum(dim=0).int()
|