|
from torch.autograd import Function
|
|
|
|
from ..utils import ext_loader
|
|
|
|
# Resolve the ``_ext`` extension through the project's ``ext_loader``,
# requesting the forward and backward ops that this file calls.
ext_module = ext_loader.load_ext(
    '_ext',
    ['assign_score_withk_forward', 'assign_score_withk_backward'],
)
|
|
|
|
|
|
class AssignScoreWithK(Function):
    r"""Perform weighted sum to generate output features according to scores.

    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
    scene_seg/lib/paconv_lib/src/gpu>`_.

    This is a memory-efficient CUDA implementation of the assign_scores
    operation, which first transforms all point features with the weight
    bank, then assembles neighbor features with ``knn_idx`` and performs a
    weighted sum with ``scores``.

    See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D for
    more detailed descriptions.

    Note:
        This implementation assumes using ``neighbor`` kernel input, which is
        (point_features - center_features, point_features).
        See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
        pointnet2/paconv.py#L128 for more details.
    """

    @staticmethod
    def forward(ctx,
                scores,
                point_features,
                center_features,
                knn_idx,
                aggregate='sum'):
        """Run the forward extension op and stash what backward needs.

        Args:
            scores (torch.Tensor): (B, npoint, K, M), predicted scores to
                aggregate weight matrices in the weight bank.
                ``npoint`` is the number of sampled centers.
                ``K`` is the number of queried neighbors.
                ``M`` is the number of weight matrices in the weight bank.
            point_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed point features to be aggregated.
            center_features (torch.Tensor): (B, N, M, out_dim)
                Pre-computed center features to be aggregated.
            knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
                We assume the first idx in each row is the idx of the center.
            aggregate (str, optional): Aggregation method.
                Can be 'sum', 'avg' or 'max'. Defaults to 'sum'.

        Returns:
            torch.Tensor: (B, out_dim, npoint, K), the aggregated features.

        Raises:
            KeyError: If ``aggregate`` is not one of the supported names.
        """
        # Integer codes handed to the extension for each aggregation name.
        agg_codes = {'sum': 0, 'avg': 1, 'max': 2}
        # Validate up front so a bad name fails with a readable message
        # before any buffer is allocated or the extension is invoked.
        if aggregate not in agg_codes:
            raise KeyError(
                'aggregate must be one of {}, got {!r}'.format(
                    list(agg_codes), aggregate))
        agg_code = agg_codes[aggregate]

        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()

        # The extension writes its result into this pre-allocated buffer.
        output = point_features.new_zeros((B, out_dim, npoint, K))
        ext_module.assign_score_withk_forward(
            point_features.contiguous(),
            center_features.contiguous(),
            scores.contiguous(),
            knn_idx.contiguous(),
            output,
            B=B,
            N0=N,
            N1=npoint,
            M=M,
            K=K,
            O=out_dim,
            aggregate=agg_code)

        # Only the inputs are needed by backward; the output itself is not
        # read there, so it is not saved.
        ctx.save_for_backward(point_features, center_features, scores,
                              knn_idx)
        ctx.agg = agg_code

        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Run the backward extension op to fill the input gradients.

        Args:
            grad_out (torch.Tensor): (B, out_dim, npoint, K)

        Returns:
            tuple: One entry per ``forward`` input, in order:

            - grad_scores (torch.Tensor): (B, npoint, K, M)
            - grad_point_features (torch.Tensor): (B, N, M, out_dim)
            - grad_center_features (torch.Tensor): (B, N, M, out_dim)
            - None for ``knn_idx``
            - None for ``aggregate``
        """
        point_features, center_features, scores, knn_idx = ctx.saved_tensors

        agg = ctx.agg

        B, N, M, out_dim = point_features.size()
        _, npoint, K, _ = scores.size()

        # Zero-initialised gradient buffers passed to the extension op.
        grad_point_features = point_features.new_zeros(point_features.shape)
        grad_center_features = center_features.new_zeros(center_features.shape)
        grad_scores = scores.new_zeros(scores.shape)

        ext_module.assign_score_withk_backward(
            grad_out.contiguous(),
            point_features.contiguous(),
            center_features.contiguous(),
            scores.contiguous(),
            knn_idx.contiguous(),
            grad_point_features,
            grad_center_features,
            grad_scores,
            B=B,
            N0=N,
            N1=npoint,
            M=M,
            K=K,
            O=out_dim,
            aggregate=agg)

        return grad_scores, grad_point_features, \
            grad_center_features, None, None
|
|
|
|
|
|
# Functional alias: ``assign_score_withk(...)`` is ``AssignScoreWithK.apply(...)``.
assign_score_withk = AssignScoreWithK.apply
|
|
|