import torch
from mmcv.utils import ext_loader

# Handle to the compiled MMCV extension; only the `iou3d_nms3d_forward`
# symbol is requested because it is the only op this module calls.
ext_module = ext_loader.load_ext('_ext', ['iou3d_nms3d_forward'])


def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
    """NMS function GPU implementation (using IoU3D).

    The difference between this implementation and nms3d in MMCV is that
    ``pre_maxsize`` and ``post_max_size`` limits are applied before and
    after NMS respectively.

    Args:
        boxes (Tensor): Input boxes with the shape of [N, 7]
            ([cx, cy, cz, l, w, h, theta]).
        scores (Tensor): Scores of boxes with the shape of [N].
        thresh (float): Overlap threshold of NMS.
        pre_maxsize (int, optional): Max number of highest-scoring boxes
            passed to NMS. Defaults to None (no limit).
        post_max_size (int, optional): Max number of boxes kept after NMS.
            Defaults to None (no limit).

    Returns:
        Tensor: Indexes after NMS (dtype ``torch.long``). They refer to
        positions in the original, unsorted ``boxes`` / ``scores``.

    Raises:
        AssertionError: If the second dimension of ``boxes`` is not 7.
    """
    assert boxes.size(1) == 7, 'Input boxes shape should be (N, 7)'

    # Indices that sort the scores from highest to lowest.
    order = scores.sort(0, descending=True)[1]
    if pre_maxsize is not None:
        order = order[:pre_maxsize]

    # Nothing to suppress: return an empty index tensor without invoking
    # the compiled op on a zero-sized input.
    if order.numel() == 0:
        return order.new_zeros(0)

    boxes = boxes[order].contiguous()

    # Output buffers filled in place by the extension: `keep` receives the
    # surviving positions (relative to the sorted boxes) and `num_out` the
    # number of valid entries in `keep`.
    keep = boxes.new_zeros(boxes.size(0), dtype=torch.long)
    num_out = boxes.new_zeros(size=(), dtype=torch.long)
    ext_module.iou3d_nms3d_forward(
        boxes, keep, num_out, nms_overlap_thresh=thresh)

    # Map positions in the sorted list back to indices of the original input.
    keep = order[keep[:num_out].to(boxes.device)].contiguous()

    if post_max_size is not None:
        keep = keep[:post_max_size]

    return keep