|
|
|
from typing import List, Tuple |
|
|
|
import torch |
|
from mmcv.cnn import Scale |
|
|
|
from mmdet.models.utils import multi_apply |
|
from mmdet.utils import reduce_mean |
|
|
|
from mmengine.model import BaseModule, bias_init_with_prob, normal_init |
|
from mmengine.structures import InstanceData |
|
from torch import Tensor, nn |
|
|
|
from mmdet3d.registry import MODELS, TASK_UTILS |
|
|
|
from mmdet3d.structures.det3d_data_sample import SampleList |
|
from mmdet3d.utils.typing_utils import (ConfigType, InstanceList, |
|
OptConfigType, OptInstanceList) |
|
|
|
|
|
@torch.no_grad() |
|
def get_points(n_voxels, voxel_size, origin):
    """Generate xyz coordinates for every voxel center of a dense grid.

    Args:
        n_voxels (Tensor): Number of voxels along each axis, shape (3, ).
        voxel_size (Tensor): Size of a single voxel, shape (3, ).
        origin (Tensor): Center of the whole voxel volume, shape (3, ).

    Returns:
        Tensor: Voxel center coordinates of shape (3, n_x, n_y, n_z).
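
    Example (a minimal sketch with toy values, for illustration only):
        >>> pts = get_points(torch.tensor([2, 2, 1]),
        ...                  torch.tensor([0.5, 0.5, 0.5]),
        ...                  torch.tensor([0., 0., 0.]))
        >>> pts.shape
        torch.Size([3, 2, 2, 1])
    """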
|
|
|
    # Build a dense (i, j, k) index grid; torch.meshgrid defaults to
    # 'ij' indexing, which keeps the (x, y, z) axis order.
    points = torch.stack(
        torch.meshgrid([
            torch.arange(n_voxels[0]),
            torch.arange(n_voxels[1]),
            torch.arange(n_voxels[2])
        ]))
|
new_origin = origin - n_voxels / 2. * voxel_size |
|
points = points * voxel_size.view(3, 1, 1, 1) + new_origin.view(3, 1, 1, 1) |
|
return points |
|
|
|
|
|
@MODELS.register_module() |
|
class NerfDetHead(BaseModule): |
|
r"""`ImVoxelNet<https://arxiv.org/abs/2106.01178>`_ head for indoor |
|
datasets. |
|
|
|
Args: |
|
        n_classes (int): Number of classes.
        n_levels (int): Number of feature levels.
        n_channels (int): Number of channels in input tensors.
        n_reg_outs (int): Number of regression layer channels.
        pts_assign_threshold (int): Minimum number of locations inside a
            box for a feature level to be assigned to it.
        pts_center_threshold (int): Maximum number of locations per box to
            be kept as positives; only the most central ones are assigned.
        prior_generator (dict): Config of the prior (points) generator.
        center_loss (dict, optional): Config of centerness loss.
            Default: dict(type='mmdet.CrossEntropyLoss', use_sigmoid=True).
        bbox_loss (dict, optional): Config of bbox loss.
            Default: dict(type='RotatedIoU3DLoss').
        cls_loss (dict, optional): Config of classification loss.
            Default: dict(type='mmdet.FocalLoss').
        train_cfg (dict, optional): Config for train stage. Defaults to None.
        test_cfg (dict, optional): Config for test stage. Defaults to None.
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to None.
|
""" |
|
|
|
def __init__(self, |
|
n_classes: int, |
|
n_levels: int, |
|
n_channels: int, |
|
n_reg_outs: int, |
|
pts_assign_threshold: int, |
|
pts_center_threshold: int, |
|
prior_generator: ConfigType, |
|
center_loss: ConfigType = dict( |
|
type='mmdet.CrossEntropyLoss', use_sigmoid=True), |
|
bbox_loss: ConfigType = dict(type='RotatedIoU3DLoss'), |
|
cls_loss: ConfigType = dict(type='mmdet.FocalLoss'), |
|
train_cfg: OptConfigType = None, |
|
test_cfg: OptConfigType = None, |
|
init_cfg: OptConfigType = None): |
|
super(NerfDetHead, self).__init__(init_cfg) |
|
self.n_classes = n_classes |
|
self.n_levels = n_levels |
|
self.n_reg_outs = n_reg_outs |
|
self.pts_assign_threshold = pts_assign_threshold |
|
self.pts_center_threshold = pts_center_threshold |
|
self.prior_generator = TASK_UTILS.build(prior_generator) |
|
self.center_loss = MODELS.build(center_loss) |
|
self.bbox_loss = MODELS.build(bbox_loss) |
|
self.cls_loss = MODELS.build(cls_loss) |
|
self.train_cfg = train_cfg |
|
self.test_cfg = test_cfg |
|
self._init_layers(n_channels, n_reg_outs, n_classes, n_levels) |
|
|
|
def _init_layers(self, n_channels, n_reg_outs, n_classes, n_levels): |
|
"""Initialize neural network layers of the head.""" |
|
self.conv_center = nn.Conv3d(n_channels, 1, 3, padding=1, bias=False) |
|
self.conv_reg = nn.Conv3d( |
|
n_channels, n_reg_outs, 3, padding=1, bias=False) |
|
self.conv_cls = nn.Conv3d(n_channels, n_classes, 3, padding=1) |
|
self.scales = nn.ModuleList([Scale(1.) for _ in range(n_levels)]) |
|
|
|
def init_weights(self): |
|
"""Initialize all layer weights.""" |
|
normal_init(self.conv_center, std=.01) |
|
normal_init(self.conv_reg, std=.01) |
|
normal_init(self.conv_cls, std=.01, bias=bias_init_with_prob(.01)) |
|
|
|
def _forward_single(self, x: Tensor, scale: Scale): |
|
"""Forward pass per level. |
|
|
|
Args: |
|
x (Tensor): Per level 3d neck output tensor. |
|
scale (mmcv.cnn.Scale): Per level multiplication weight. |
|
|
|
Returns: |
|
tuple[Tensor]: Centerness, bbox and classification predictions. |
|
""" |
|
return (self.conv_center(x), torch.exp(scale(self.conv_reg(x))), |
|
self.conv_cls(x)) |
|
|
|
    def forward(self, x):
        """Forward pass over all feature levels.

        Args:
            x (list[Tensor]): Per level 3D neck output tensors.

        Returns:
            tuple[list[Tensor]]: Per level centerness, bbox and
            classification predictions.
        """
        return multi_apply(self._forward_single, x, self.scales)
|
|
|
def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList, |
|
**kwargs) -> dict: |
|
"""Perform forward propagation and loss calculation of the detection |
|
head on the features of the upstream network. |
|
|
|
        Args:
            x (tuple[Tensor]): Features from the upstream network, each is
                a 5D tensor; the last element is the valid mask volume.
            batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data
                Samples. It usually includes information such as
                ``gt_instances_3d`` and ``ignored_instances``.
|
|
|
Returns: |
|
dict: A dictionary of loss components. |
|
""" |
|
        # The last element of ``x`` is the valid mask volume; the rest are
        # the per level feature volumes fed through the head.
        valid_pred = x[-1]
        outs = self(x[:-1])
|
|
|
batch_gt_instances_3d = [] |
|
batch_gt_instances_ignore = [] |
|
batch_input_metas = [] |
|
for data_sample in batch_data_samples: |
|
batch_input_metas.append(data_sample.metainfo) |
|
batch_gt_instances_3d.append(data_sample.gt_instances_3d) |
|
batch_gt_instances_ignore.append( |
|
data_sample.get('ignored_instances', None)) |
|
|
|
loss_inputs = outs + (valid_pred, batch_gt_instances_3d, |
|
batch_input_metas, batch_gt_instances_ignore) |
|
losses = self.loss_by_feat(*loss_inputs) |
|
return losses |
|
|
|
def loss_by_feat(self, |
|
center_preds: List[List[Tensor]], |
|
bbox_preds: List[List[Tensor]], |
|
cls_preds: List[List[Tensor]], |
|
valid_pred: Tensor, |
|
batch_gt_instances_3d: InstanceList, |
|
batch_input_metas: List[dict], |
|
batch_gt_instances_ignore: OptInstanceList = None, |
|
**kwargs) -> dict: |
|
"""Per scene loss function. |
|
|
|
Args: |
|
center_preds (list[list[Tensor]]): Centerness predictions for |
|
all scenes. The first list contains predictions from different |
|
levels. The second list contains predictions in a mini-batch. |
|
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes. |
|
The first list contains predictions from different |
|
levels. The second list contains predictions in a mini-batch. |
|
cls_preds (list[list[Tensor]]): Classification predictions for all |
|
scenes. The first list contains predictions from different |
|
levels. The second list contains predictions in a mini-batch. |
|
valid_pred (Tensor): Valid mask prediction for all scenes. |
|
            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
                gt_instance_3d. It usually includes ``bboxes_3d``,
                ``labels_3d``, ``depths``, ``centers_2d`` and attributes.
|
batch_input_metas (list[dict]): Meta information of each image, |
|
e.g., image size, scaling factor, etc. |
|
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): |
|
Batch of gt_instances_ignore. It includes ``bboxes`` attribute |
|
data that is ignored during training and testing. |
|
Defaults to None. |
|
|
|
Returns: |
|
dict: Centerness, bbox, and classification loss values. |
|
""" |
|
valid_preds = self._upsample_valid_preds(valid_pred, center_preds) |
|
center_losses, bbox_losses, cls_losses = [], [], [] |
|
for i in range(len(batch_input_metas)): |
|
center_loss, bbox_loss, cls_loss = self._loss_by_feat_single( |
|
center_preds=[x[i] for x in center_preds], |
|
bbox_preds=[x[i] for x in bbox_preds], |
|
cls_preds=[x[i] for x in cls_preds], |
|
valid_preds=[x[i] for x in valid_preds], |
|
input_meta=batch_input_metas[i], |
|
gt_bboxes=batch_gt_instances_3d[i].bboxes_3d, |
|
gt_labels=batch_gt_instances_3d[i].labels_3d) |
|
center_losses.append(center_loss) |
|
bbox_losses.append(bbox_loss) |
|
cls_losses.append(cls_loss) |
|
return dict( |
|
center_loss=torch.mean(torch.stack(center_losses)), |
|
bbox_loss=torch.mean(torch.stack(bbox_losses)), |
|
cls_loss=torch.mean(torch.stack(cls_losses))) |
|
|
|
    def _loss_by_feat_single(self, center_preds, bbox_preds, cls_preds,
                             valid_preds, input_meta, gt_bboxes, gt_labels):
        """Per scene loss function.

        Args:
            center_preds (list[Tensor]): Centerness predictions for all
                levels.
            bbox_preds (list[Tensor]): Bbox predictions for all levels.
            cls_preds (list[Tensor]): Classification predictions for all
                levels.
            valid_preds (list[Tensor]): Upsampled valid masks for all
                levels.
            input_meta (dict): Scene meta info.
            gt_bboxes (BaseInstance3DBoxes): Ground truth boxes.
            gt_labels (Tensor): Ground truth labels.

        Returns:
            tuple[Tensor]: Centerness, bbox and classification loss values.
        """
|
featmap_sizes = [featmap.size()[-3:] for featmap in center_preds] |
|
points = self._get_points( |
|
featmap_sizes=featmap_sizes, |
|
origin=input_meta['lidar2img']['origin'], |
|
device=gt_bboxes.device) |
|
center_targets, bbox_targets, cls_targets = self._get_targets( |
|
points, gt_bboxes, gt_labels) |
|
|
|
center_preds = torch.cat( |
|
[x.permute(1, 2, 3, 0).reshape(-1) for x in center_preds]) |
|
bbox_preds = torch.cat([ |
|
x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in bbox_preds |
|
]) |
|
cls_preds = torch.cat( |
|
[x.permute(1, 2, 3, 0).reshape(-1, x.shape[0]) for x in cls_preds]) |
|
valid_preds = torch.cat( |
|
[x.permute(1, 2, 3, 0).reshape(-1) for x in valid_preds]) |
|
points = torch.cat(points) |
|
|
|
|
|
        # Classification loss over valid locations, averaged by the number
        # of positive locations (synced across GPUs via reduce_mean).
        pos_inds = torch.nonzero(
            torch.logical_and(cls_targets >= 0, valid_preds)).squeeze(1)
|
n_pos = points.new_tensor(len(pos_inds)) |
|
n_pos = max(reduce_mean(n_pos), 1.) |
|
if torch.any(valid_preds): |
|
cls_loss = self.cls_loss( |
|
cls_preds[valid_preds], |
|
cls_targets[valid_preds], |
|
avg_factor=n_pos) |
|
else: |
|
cls_loss = cls_preds[valid_preds].sum() |
|
|
|
|
|
        # Centerness and bbox losses are computed on positive locations only.
        pos_center_preds = center_preds[pos_inds]
        pos_bbox_preds = bbox_preds[pos_inds]
|
if len(pos_inds) > 0: |
|
pos_center_targets = center_targets[pos_inds] |
|
pos_bbox_targets = bbox_targets[pos_inds] |
|
pos_points = points[pos_inds] |
|
center_loss = self.center_loss( |
|
pos_center_preds, pos_center_targets, avg_factor=n_pos) |
|
bbox_loss = self.bbox_loss( |
|
self._bbox_pred_to_bbox(pos_points, pos_bbox_preds), |
|
pos_bbox_targets, |
|
weight=pos_center_targets, |
|
avg_factor=pos_center_targets.sum()) |
|
else: |
|
center_loss = pos_center_preds.sum() |
|
bbox_loss = pos_bbox_preds.sum() |
|
return center_loss, bbox_loss, cls_loss |
|
|
|
def predict(self, |
|
x: Tuple[Tensor], |
|
batch_data_samples: SampleList, |
|
rescale: bool = False) -> InstanceList: |
|
"""Perform forward propagation of the 3D detection head and predict |
|
detection results on the features of the upstream network. |
|
|
|
Args: |
|
x (tuple[Tensor]): Multi-level features from the |
|
upstream network, each is a 4D-tensor. |
|
batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data |
|
Samples. It usually includes information such as |
|
`gt_instance_3d`, `gt_pts_panoptic_seg` and |
|
`gt_pts_sem_seg`. |
|
rescale (bool, optional): Whether to rescale the results. |
|
Defaults to False. |
|
|
|
Returns: |
|
list[:obj:`InstanceData`]: Detection results of each sample |
|
after the post process. |
|
Each item usually contains following keys. |
|
|
|
- scores_3d (Tensor): Classification scores, has a shape |
|
(num_instances, ) |
|
- labels_3d (Tensor): Labels of bboxes, has a shape |
|
(num_instances, ). |
|
- bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, |
|
contains a tensor with shape (num_instances, C), where |
|
C >= 6. |
|
""" |
|
batch_input_metas = [ |
|
data_samples.metainfo for data_samples in batch_data_samples |
|
] |
|
valid_pred = x[-1] |
|
outs = self(x[:-1]) |
|
predictions = self.predict_by_feat( |
|
*outs, |
|
valid_pred=valid_pred, |
|
batch_input_metas=batch_input_metas, |
|
rescale=rescale) |
|
return predictions |
|
|
|
def predict_by_feat(self, center_preds: List[List[Tensor]], |
|
bbox_preds: List[List[Tensor]], |
|
cls_preds: List[List[Tensor]], valid_pred: Tensor, |
|
batch_input_metas: List[dict], |
|
**kwargs) -> List[InstanceData]: |
|
"""Generate boxes for all scenes. |
|
|
|
Args: |
|
center_preds (list[list[Tensor]]): Centerness predictions for |
|
all scenes. |
|
bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes. |
|
cls_preds (list[list[Tensor]]): Classification predictions for all |
|
scenes. |
|
valid_pred (Tensor): Valid mask prediction for all scenes. |
|
batch_input_metas (list[dict]): Meta infos for all scenes. |
|
|
|
        Returns:
            list[:obj:`InstanceData`]: Detection results for all scenes
            after the post process.
|
""" |
|
valid_preds = self._upsample_valid_preds(valid_pred, center_preds) |
|
results = [] |
|
for i in range(len(batch_input_metas)): |
|
results.append( |
|
self._predict_by_feat_single( |
|
center_preds=[x[i] for x in center_preds], |
|
bbox_preds=[x[i] for x in bbox_preds], |
|
cls_preds=[x[i] for x in cls_preds], |
|
valid_preds=[x[i] for x in valid_preds], |
|
input_meta=batch_input_metas[i])) |
|
return results |
|
|
|
def _predict_by_feat_single(self, center_preds: List[Tensor], |
|
bbox_preds: List[Tensor], |
|
cls_preds: List[Tensor], |
|
valid_preds: List[Tensor], |
|
input_meta: dict) -> InstanceData: |
|
"""Generate boxes for single sample. |
|
|
|
Args: |
|
center_preds (list[Tensor]): Centerness predictions for all levels. |
|
bbox_preds (list[Tensor]): Bbox predictions for all levels. |
|
cls_preds (list[Tensor]): Classification predictions for all |
|
levels. |
|
            valid_preds (list[Tensor]): Upsampled valid masks for all
                feature levels.
            input_meta (dict): Scene meta info.

        Returns:
            :obj:`InstanceData`: Detection results with ``bboxes_3d``,
            ``scores_3d`` and ``labels_3d``.
|
""" |
|
featmap_sizes = [featmap.size()[-3:] for featmap in center_preds] |
|
points = self._get_points( |
|
featmap_sizes=featmap_sizes, |
|
origin=input_meta['lidar2img']['origin'], |
|
device=center_preds[0].device) |
|
mlvl_bboxes, mlvl_scores = [], [] |
|
for center_pred, bbox_pred, cls_pred, valid_pred, point in zip( |
|
center_preds, bbox_preds, cls_preds, valid_preds, points): |
|
center_pred = center_pred.permute(1, 2, 3, 0).reshape(-1, 1) |
|
bbox_pred = bbox_pred.permute(1, 2, 3, |
|
0).reshape(-1, bbox_pred.shape[0]) |
|
cls_pred = cls_pred.permute(1, 2, 3, |
|
0).reshape(-1, cls_pred.shape[0]) |
|
valid_pred = valid_pred.permute(1, 2, 3, 0).reshape(-1, 1) |
|
scores = cls_pred.sigmoid() * center_pred.sigmoid() * valid_pred |
|
max_scores, _ = scores.max(dim=1) |
|
|
|
if len(scores) > self.test_cfg.nms_pre > 0: |
|
_, ids = max_scores.topk(self.test_cfg.nms_pre) |
|
bbox_pred = bbox_pred[ids] |
|
scores = scores[ids] |
|
point = point[ids] |
|
|
|
bboxes = self._bbox_pred_to_bbox(point, bbox_pred) |
|
mlvl_bboxes.append(bboxes) |
|
mlvl_scores.append(scores) |
|
|
|
bboxes = torch.cat(mlvl_bboxes) |
|
scores = torch.cat(mlvl_scores) |
|
bboxes, scores, labels = self._nms(bboxes, scores, input_meta) |
|
|
|
bboxes = input_meta['box_type_3d']( |
|
bboxes, box_dim=6, with_yaw=False, origin=(.5, .5, .5)) |
|
|
|
results = InstanceData() |
|
results.bboxes_3d = bboxes |
|
results.scores_3d = scores |
|
results.labels_3d = labels |
|
return results |
|
|
|
@staticmethod |
|
def _upsample_valid_preds(valid_pred, features): |
|
"""Upsample valid mask predictions. |
|
|
|
Args: |
|
            valid_pred (Tensor): Valid mask prediction of shape
                (N, 1, N_x, N_y, N_z).
            features (list[Tensor]): Feature tensors whose spatial sizes
                the mask is upsampled to.

        Returns:
            list[Tensor]: Upsampled boolean valid masks for all feature
            levels.
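
        Example (a minimal sketch with dummy tensors; shapes are
        illustrative only):
            >>> valid = torch.ones(1, 1, 4, 4, 2)
            >>> feats = [torch.zeros(1, 8, 8, 8, 4),
            ...          torch.zeros(1, 8, 4, 4, 2)]
            >>> masks = NerfDetHead._upsample_valid_preds(valid, feats)
            >>> [tuple(m.shape) for m in masks]
            [(1, 1, 8, 8, 4), (1, 1, 4, 4, 2)]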
|
""" |
|
return [ |
|
nn.Upsample(size=x.shape[-3:], |
|
mode='trilinear')(valid_pred).round().bool() |
|
for x in features |
|
] |
|
|
|
@torch.no_grad() |
|
    def _get_points(self, featmap_sizes, origin, device):
        """Generate final locations for all feature levels.

        Args:
            featmap_sizes (list[tuple]): Per level feature map sizes.
            origin (tuple): Center of the voxel volume.
            device (torch.device | str): Device to create the points on.

        Returns:
            list[Tensor]: Per level locations of shape (N_points, 3).
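
        Example (a minimal sketch; ``self`` is unused here, so the unbound
        call is for illustration only):
            >>> pts = NerfDetHead._get_points(None, [(2, 2, 2)],
            ...                               (0., 0., 0.), 'cpu')
            >>> pts[0].shape
            torch.Size([8, 3])
        """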
|
mlvl_points = [] |
|
        # NOTE: the base voxel size is hardcoded here and doubled for each
        # coarser feature level below.
        tmp_voxel_size = [.16, .16, .2]
|
for i, featmap_size in enumerate(featmap_sizes): |
|
mlvl_points.append( |
|
get_points( |
|
n_voxels=torch.tensor(featmap_size), |
|
voxel_size=torch.tensor(tmp_voxel_size) * (2**i), |
|
origin=torch.tensor(origin)).reshape(3, -1).transpose( |
|
0, 1).to(device)) |
|
return mlvl_points |
|
|
|
    def _bbox_pred_to_bbox(self, points, bbox_pred):
        """Convert distances to the six box faces into an axis-aligned box.

        Args:
            points (Tensor): Locations of shape (N_points, 3).
            bbox_pred (Tensor): Face distances (dx_min, dx_max, dy_min,
                dy_max, dz_min, dz_max) of shape (N_points, 6).

        Returns:
            Tensor: Boxes of shape (N_points, 6) in
                (x1, y1, z1, x2, y2, z2) corner format.
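
        Example (a minimal sketch; ``self`` is unused, so the unbound call
        below is for illustration only):
            >>> pts = torch.tensor([[1., 2., 3.]])
            >>> pred = torch.tensor([[.5, .5, .5, .5, .5, .5]])
            >>> NerfDetHead._bbox_pred_to_bbox(None, pts, pred)
            tensor([[0.5000, 1.5000, 2.5000, 1.5000, 2.5000, 3.5000]])
        """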
|
return torch.stack([ |
|
points[:, 0] - bbox_pred[:, 0], points[:, 1] - bbox_pred[:, 2], |
|
points[:, 2] - bbox_pred[:, 4], points[:, 0] + bbox_pred[:, 1], |
|
points[:, 1] + bbox_pred[:, 3], points[:, 2] + bbox_pred[:, 5] |
|
], -1) |
|
|
|
    def _bbox_pred_to_loss(self, points, bbox_preds):
        """Thin wrapper around :meth:`_bbox_pred_to_bbox`."""
        return self._bbox_pred_to_bbox(points, bbox_preds)
|
|
|
|
|
@staticmethod |
|
def _get_face_distances(points, boxes): |
|
"""Calculate distances from point to box faces. |
|
|
|
Args: |
|
points (Tensor): Final locations of shape (N_points, N_boxes, 3). |
|
            boxes (Tensor): 3D boxes of shape (N_points, N_boxes, 6).
|
|
|
Returns: |
|
Tensor: Face distances of shape (N_points, N_boxes, 6), |
|
(dx_min, dx_max, dy_min, dy_max, dz_min, dz_max). |
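
        Example (a minimal sketch with one point and one unit box):
            >>> pts = torch.tensor([[[0.1, 0.2, 0.3]]])
            >>> boxes = torch.tensor([[[0., 0., 0., 1., 1., 1.]]])
            >>> NerfDetHead._get_face_distances(pts, boxes)
            tensor([[[0.6000, 0.4000, 0.7000, 0.3000, 0.8000, 0.2000]]])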
|
""" |
|
dx_min = points[..., 0] - boxes[..., 0] + boxes[..., 3] / 2 |
|
dx_max = boxes[..., 0] + boxes[..., 3] / 2 - points[..., 0] |
|
dy_min = points[..., 1] - boxes[..., 1] + boxes[..., 4] / 2 |
|
dy_max = boxes[..., 1] + boxes[..., 4] / 2 - points[..., 1] |
|
dz_min = points[..., 2] - boxes[..., 2] + boxes[..., 5] / 2 |
|
dz_max = boxes[..., 2] + boxes[..., 5] / 2 - points[..., 2] |
|
return torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max), |
|
dim=-1) |
|
|
|
@staticmethod |
|
def _get_centerness(face_distances): |
|
"""Compute point centerness w.r.t containing box. |
|
|
|
Args: |
|
face_distances (Tensor): Face distances of shape (B, N, 6), |
|
(dx_min, dx_max, dy_min, dy_max, dz_min, dz_max). |
|
|
|
Returns: |
|
Tensor: Centerness of shape (B, N). |
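
        Example (a minimal sketch; a perfectly centered point scores 1):
            >>> d = torch.tensor([[[0.5, 0.5, 0.5, 0.5, 0.5, 0.5]]])
            >>> NerfDetHead._get_centerness(d)
            tensor([[1.]])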
|
""" |
|
x_dims = face_distances[..., [0, 1]] |
|
y_dims = face_distances[..., [2, 3]] |
|
z_dims = face_distances[..., [4, 5]] |
|
centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \ |
|
y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \ |
|
z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0] |
|
return torch.sqrt(centerness_targets) |
|
|
|
@torch.no_grad() |
|
def _get_targets(self, points, gt_bboxes, gt_labels): |
|
"""Compute targets for final locations for a single scene. |
|
|
|
Args: |
|
points (list[Tensor]): Final locations for all levels. |
|
gt_bboxes (BaseInstance3DBoxes): Ground truth boxes. |
|
gt_labels (Tensor): Ground truth labels. |
|
|
|
Returns: |
|
tuple[Tensor]: Centerness, bbox and classification |
|
targets for all locations. |
|
""" |
|
float_max = 1e8 |
|
expanded_scales = [ |
|
points[i].new_tensor(i).expand(len(points[i])).to(gt_labels.device) |
|
for i in range(len(points)) |
|
] |
|
points = torch.cat(points, dim=0).to(gt_labels.device) |
|
scales = torch.cat(expanded_scales, dim=0) |
|
|
|
|
|
n_points = len(points) |
|
n_boxes = len(gt_bboxes) |
|
volumes = gt_bboxes.volume.to(points.device) |
|
volumes = volumes.expand(n_points, n_boxes).contiguous() |
|
gt_bboxes = torch.cat( |
|
(gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:6]), dim=1) |
|
gt_bboxes = gt_bboxes.to(points.device).expand(n_points, n_boxes, 6) |
|
expanded_points = points.unsqueeze(1).expand(n_points, n_boxes, 3) |
|
bbox_targets = self._get_face_distances(expanded_points, gt_bboxes) |
|
|
|
|
|
        # Condition 1: the location must lie inside a gt box.
        inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
|
|
|
|
|
|
|
        # Condition 2: assign each box to the coarsest feature level that
        # still has at least ``pts_assign_threshold`` locations inside it.
        n_pos_points_per_scale = []
|
for i in range(self.n_levels): |
|
n_pos_points_per_scale.append( |
|
torch.sum(inside_gt_bbox_mask[scales == i], dim=0)) |
|
|
|
n_pos_points_per_scale = torch.stack(n_pos_points_per_scale, dim=0) |
|
lower_limit_mask = n_pos_points_per_scale < self.pts_assign_threshold |
|
|
|
extra = torch.arange(self.n_levels, 0, -1).unsqueeze(1).expand( |
|
self.n_levels, n_boxes).to(lower_limit_mask.device) |
|
lower_index = torch.argmax(lower_limit_mask.int() * extra, dim=0) - 1 |
|
lower_index = torch.where(lower_index < 0, |
|
torch.zeros_like(lower_index), lower_index) |
|
all_upper_limit_mask = torch.all( |
|
torch.logical_not(lower_limit_mask), dim=0) |
|
best_scale = torch.where( |
|
all_upper_limit_mask, |
|
torch.ones_like(all_upper_limit_mask) * self.n_levels - 1, |
|
lower_index) |
|
|
|
best_scale = torch.unsqueeze(best_scale, 0).expand(n_points, n_boxes) |
|
scales = torch.unsqueeze(scales, 1).expand(n_points, n_boxes) |
|
inside_best_scale_mask = best_scale == scales |
|
|
|
|
|
        # Condition 3: keep only the ``pts_center_threshold`` most central
        # locations per box.
        centerness = self._get_centerness(bbox_targets)
|
centerness = torch.where(inside_gt_bbox_mask, centerness, |
|
torch.ones_like(centerness) * -1) |
|
centerness = torch.where(inside_best_scale_mask, centerness, |
|
torch.ones_like(centerness) * -1) |
|
        top_centerness = torch.topk(
            centerness, min(self.pts_center_threshold + 1, len(centerness)),
            dim=0).values[-1]
|
inside_top_centerness_mask = centerness > top_centerness.unsqueeze(0) |
|
|
|
|
|
|
|
        # Condition 4: if several boxes remain, assign the location to the
        # one with the smallest volume.
        volumes = torch.where(inside_gt_bbox_mask, volumes,
|
torch.ones_like(volumes) * float_max) |
|
volumes = torch.where(inside_best_scale_mask, volumes, |
|
torch.ones_like(volumes) * float_max) |
|
volumes = torch.where(inside_top_centerness_mask, volumes, |
|
torch.ones_like(volumes) * float_max) |
|
min_area, min_area_inds = volumes.min(dim=1) |
|
|
|
labels = gt_labels[min_area_inds] |
|
labels = torch.where(min_area == float_max, |
|
torch.ones_like(labels) * -1, labels) |
|
bbox_targets = bbox_targets[range(n_points), min_area_inds] |
|
centerness_targets = self._get_centerness(bbox_targets) |
|
|
|
return centerness_targets, self._bbox_pred_to_bbox( |
|
points, bbox_targets), labels |
|
|
|
    def _nms(self, bboxes, scores, img_meta):
        """Score thresholding and class-aware NMS for a single scene.

        Args:
            bboxes (Tensor): Boxes of shape (N, 6) in
                (x1, y1, z1, x2, y2, z2) corner format.
            scores (Tensor): Per class scores of shape (N, n_classes).
            img_meta (dict): Scene meta data.

        Returns:
            tuple[Tensor]: Selected boxes converted to (center, size)
            format, their scores and labels.
        """
|
scores, labels = scores.max(dim=1) |
|
ids = scores > self.test_cfg.score_thr |
|
bboxes = bboxes[ids] |
|
scores = scores[ids] |
|
labels = labels[ids] |
|
ids = self.aligned_3d_nms(bboxes, scores, labels, |
|
self.test_cfg.iou_thr) |
|
bboxes = bboxes[ids] |
|
bboxes = torch.stack( |
|
((bboxes[:, 0] + bboxes[:, 3]) / 2., |
|
(bboxes[:, 1] + bboxes[:, 4]) / 2., |
|
(bboxes[:, 2] + bboxes[:, 5]) / 2., bboxes[:, 3] - bboxes[:, 0], |
|
bboxes[:, 4] - bboxes[:, 1], bboxes[:, 5] - bboxes[:, 2]), |
|
dim=1) |
|
return bboxes, scores[ids], labels[ids] |
|
|
|
@staticmethod |
|
def aligned_3d_nms(boxes, scores, classes, thresh): |
|
"""3d nms for aligned boxes. |
|
|
|
Args: |
|
boxes (torch.Tensor): Aligned box with shape [n, 6]. |
|
scores (torch.Tensor): Scores of each box. |
|
classes (torch.Tensor): Class of each box. |
|
            thresh (float): IoU threshold for NMS.
|
|
|
Returns: |
|
torch.Tensor: Indices of selected boxes. |
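
        Example (a minimal sketch: the second box overlaps the first and is
        suppressed, the third is disjoint and kept):
            >>> boxes = torch.tensor([[0., 0., 0., 1., 1., 1.],
            ...                       [0.1, 0.1, 0.1, 1.1, 1.1, 1.1],
            ...                       [2., 2., 2., 3., 3., 3.]])
            >>> scores = torch.tensor([0.9, 0.8, 0.7])
            >>> classes = torch.tensor([0, 0, 0])
            >>> NerfDetHead.aligned_3d_nms(boxes, scores, classes, 0.25)
            tensor([0, 2])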
|
""" |
|
x1 = boxes[:, 0] |
|
y1 = boxes[:, 1] |
|
z1 = boxes[:, 2] |
|
x2 = boxes[:, 3] |
|
y2 = boxes[:, 4] |
|
z2 = boxes[:, 5] |
|
area = (x2 - x1) * (y2 - y1) * (z2 - z1) |
|
        zero = boxes.new_zeros(1)
|
|
|
score_sorted = torch.argsort(scores) |
|
pick = [] |
|
        while score_sorted.shape[0] != 0:
|
last = score_sorted.shape[0] |
|
i = score_sorted[-1] |
|
pick.append(i) |
|
|
|
xx1 = torch.max(x1[i], x1[score_sorted[:last - 1]]) |
|
yy1 = torch.max(y1[i], y1[score_sorted[:last - 1]]) |
|
zz1 = torch.max(z1[i], z1[score_sorted[:last - 1]]) |
|
xx2 = torch.min(x2[i], x2[score_sorted[:last - 1]]) |
|
yy2 = torch.min(y2[i], y2[score_sorted[:last - 1]]) |
|
zz2 = torch.min(z2[i], z2[score_sorted[:last - 1]]) |
|
classes1 = classes[i] |
|
classes2 = classes[score_sorted[:last - 1]] |
|
inter_l = torch.max(zero, xx2 - xx1) |
|
inter_w = torch.max(zero, yy2 - yy1) |
|
inter_h = torch.max(zero, zz2 - zz1) |
|
|
|
inter = inter_l * inter_w * inter_h |
|
iou = inter / (area[i] + area[score_sorted[:last - 1]] - inter) |
|
iou = iou * (classes1 == classes2).float() |
|
score_sorted = score_sorted[torch.nonzero( |
|
iou <= thresh, as_tuple=False).flatten()] |
|
|
|
indices = boxes.new_tensor(pick, dtype=torch.long) |
|
return indices |
|
|