|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from voxelnext_3d_box.utils import centernet_utils |
|
import spconv.pytorch as spconv |
|
import copy |
|
from spconv.core import ConvAlgo |
|
|
|
|
|
class SeparateHead(nn.Module):
    """Set of parallel sparse-convolution branches, one per prediction target.

    For every entry ``name -> {'out_channels': C, 'num_conv': N}`` of
    ``sep_head_dict`` a branch is registered as the submodule ``name``.  A
    branch is made of ``N - 1`` blocks of ``SubMConv2d`` (``kernel_size``,
    ``input_channels -> input_channels``) + ``BatchNorm1d`` + ``ReLU``,
    followed by a 1x1 ``SubMConv2d`` (``input_channels -> C``) that always
    carries a bias.

    Args:
        input_channels: Channel count of the incoming sparse features.
        sep_head_dict: Mapping from branch name to a dict providing the keys
            ``'out_channels'`` and ``'num_conv'``.
        kernel_size: Kernel size of the intermediate convolutions; their
            padding is ``kernel_size // 2``.
        use_bias: Whether the intermediate convolutions use a bias term.
    """

    def __init__(self, input_channels, sep_head_dict, kernel_size, use_bias=False):
        super().__init__()
        self.sep_head_dict = sep_head_dict

        for cur_name in self.sep_head_dict:
            output_channels = self.sep_head_dict[cur_name]['out_channels']
            num_conv = self.sep_head_dict[cur_name]['num_conv']

            # Intermediate blocks of one branch all use the branch name as
            # ``indice_key``; the output conv uses ``<name>out``.
            fc_list = [
                spconv.SparseSequential(
                    spconv.SubMConv2d(
                        input_channels, input_channels, kernel_size,
                        padding=int(kernel_size // 2), bias=use_bias,
                        indice_key=cur_name, algo=ConvAlgo.Native,
                    ),
                    nn.BatchNorm1d(input_channels),
                    nn.ReLU(),
                )
                for _ in range(num_conv - 1)
            ]
            fc_list.append(
                spconv.SubMConv2d(
                    input_channels, output_channels, 1, bias=True,
                    indice_key=cur_name + 'out', algo=ConvAlgo.Native,
                )
            )

            # Plain ``setattr`` goes through ``nn.Module.__setattr__`` and so
            # registers the branch as a submodule under ``cur_name``.
            setattr(self, cur_name, nn.Sequential(*fc_list))

    def forward(self, x):
        """Run every branch on ``x`` and collect the resulting features.

        Args:
            x: Input handed unchanged to each branch.

        Returns:
            Dict mapping each branch name to the ``.features`` attribute of
            that branch's output.
        """
        return {
            cur_name: getattr(self, cur_name)(x).features
            for cur_name in self.sep_head_dict
        }
|
|
|
|
|
class VoxelNeXtHead(nn.Module):
    """Detection head that turns a sparse feature tensor into box predictions.

    One ``SeparateHead`` is created per class group of
    ``CLASS_NAMES_EACH_HEAD``.  ``forward`` runs every head on the encoded
    sparse tensor and decodes the raw outputs into per-sample boxes, scores
    and labels through ``centernet_utils.decode_bbox_from_voxels_nuscenes``.

    Args:
        class_names: Ordered list of all class names; label ids are derived
            from positions in this list.
        point_cloud_range: Values stored as a tensor and forwarded to the
            box decoder.
        voxel_size: Values stored as a tensor and forwarded to the box
            decoder.
        kernel_size_head: Kernel size passed to every ``SeparateHead``.
        CLASS_NAMES_EACH_HEAD: Iterable of class-name groups, one group per
            head.  Names that are not in ``class_names`` are dropped.
        SEPARATE_HEAD_CFG: Config object exposing ``HEAD_DICT`` (branch
            definitions, deep-copied per head) and ``HEAD_ORDER``.
        POST_PROCESSING: Config object exposing ``POST_CENTER_LIMIT_RANGE``,
            ``MAX_OBJ_PER_SAMPLE`` and ``SCORE_THRESH``.

    Raises:
        AssertionError: If the number of kept class names over all groups
            differs from ``len(class_names)``.
    """

    def __init__(self, class_names, point_cloud_range, voxel_size, kernel_size_head,
                 CLASS_NAMES_EACH_HEAD, SEPARATE_HEAD_CFG, POST_PROCESSING):
        super().__init__()
        # Kept as plain tensor attributes (not buffers); they are moved to
        # the prediction device at decode time.
        self.point_cloud_range = torch.Tensor(point_cloud_range)
        self.voxel_size = torch.Tensor(voxel_size)
        # Fixed value handed to the decoder as ``feature_map_stride``.
        self.feature_map_stride = 8

        self.class_names = class_names
        self.class_names_each_head = []
        self.class_id_mapping_each_head = []
        self.POST_PROCESSING = POST_PROCESSING

        for cur_class_names in CLASS_NAMES_EACH_HEAD:
            # Filter once and reuse for both the names and the id mapping.
            kept_names = [x for x in cur_class_names if x in class_names]
            self.class_names_each_head.append(kept_names)
            # Maps a head-local class id to its index in ``class_names``.
            cur_class_id_mapping = torch.from_numpy(np.array(
                [self.class_names.index(x) for x in kept_names]
            ))
            self.class_id_mapping_each_head.append(cur_class_id_mapping)

        total_classes = sum(len(x) for x in self.class_names_each_head)
        assert total_classes == len(self.class_names), f'class_names_each_head={self.class_names_each_head}'

        self.heads_list = nn.ModuleList()
        self.separate_head_cfg = SEPARATE_HEAD_CFG
        for cur_class_names in self.class_names_each_head:
            cur_head_dict = copy.deepcopy(self.separate_head_cfg.HEAD_DICT)
            # The heatmap branch gets one output channel per class of the group.
            cur_head_dict['hm'] = dict(out_channels=len(cur_class_names), num_conv=2)
            self.heads_list.append(
                SeparateHead(
                    input_channels=128,
                    sep_head_dict=cur_head_dict,
                    kernel_size=kernel_size_head,
                    use_bias=True,
                )
            )
        self.forward_ret_dict = {}

    def generate_predicted_boxes(self, batch_size, pred_dicts, voxel_indices, spatial_shape):
        """Decode the raw outputs of every head into per-sample predictions.

        Args:
            batch_size: Number of samples in the batch.
            pred_dicts: Non-empty list with one dict per head holding the keys
                ``'hm'``, ``'center'``, ``'center_z'``, ``'dim'``, ``'rot'``
                and, when ``'vel'`` is listed in ``HEAD_ORDER``, ``'vel'``.
            voxel_indices: Index tensor of the sparse input; passed to the
                decoder as ``indices``.
            spatial_shape: Unused; kept so the signature stays unchanged.

        Returns:
            List of ``batch_size`` dicts.  ``'pred_boxes'``, ``'pred_scores'``,
            ``'pred_labels'`` and ``'voxel_ids'`` are concatenated over all
            heads; ``'pred_labels'`` are 1-based positions in ``class_names``.
            ``'pred_ious'`` is left as the list of per-head decoder outputs.
        """
        device = pred_dicts[0]['hm'].device
        post_process_cfg = self.POST_PROCESSING
        post_center_limit_range = torch.tensor(post_process_cfg.POST_CENTER_LIMIT_RANGE).float().to(device)

        # Loop-invariant inputs of the decoder: build / move them once instead
        # of once per head.
        point_cloud_range = self.point_cloud_range.to(device)
        voxel_size = self.voxel_size.to(device)
        # Row index of every voxel, carried through the decoder as an extra
        # feature and returned as ``voxel_ids``.
        voxel_row_ids = torch.arange(
            voxel_indices.shape[0], device=voxel_indices.device
        ).unsqueeze(-1)
        has_vel = 'vel' in self.separate_head_cfg.HEAD_ORDER

        ret_dict = [{
            'pred_boxes': [],
            'pred_scores': [],
            'pred_labels': [],
            'pred_ious': [],
            'voxel_ids': []
        } for _ in range(batch_size)]

        for idx, pred_dict in enumerate(pred_dicts):
            batch_hm = pred_dict['hm'].sigmoid()
            batch_center = pred_dict['center']
            batch_center_z = pred_dict['center_z']
            batch_dim = pred_dict['dim'].exp()
            # Column 0 of 'rot' is passed on as cosine, column 1 as sine.
            batch_rot_cos = pred_dict['rot'][:, 0].unsqueeze(dim=1)
            batch_rot_sin = pred_dict['rot'][:, 1].unsqueeze(dim=1)
            batch_iou = None  # no IoU input is provided to the decoder
            batch_vel = pred_dict['vel'] if has_vel else None

            final_pred_dicts = centernet_utils.decode_bbox_from_voxels_nuscenes(
                batch_size=batch_size, indices=voxel_indices,
                obj=batch_hm,
                rot_cos=batch_rot_cos,
                rot_sin=batch_rot_sin,
                center=batch_center, center_z=batch_center_z,
                dim=batch_dim, vel=batch_vel, iou=batch_iou,
                point_cloud_range=point_cloud_range, voxel_size=voxel_size,
                feature_map_stride=self.feature_map_stride,
                K=post_process_cfg.MAX_OBJ_PER_SAMPLE,
                score_thresh=post_process_cfg.SCORE_THRESH,
                post_center_limit_range=post_center_limit_range,
                add_features=voxel_row_ids
            )

            # Same mapping for every sample of this head: move it once.
            class_id_mapping = self.class_id_mapping_each_head[idx].to(device)
            for k, final_dict in enumerate(final_pred_dicts):
                # Translate head-local class ids to indices in ``class_names``.
                final_dict['pred_labels'] = class_id_mapping[final_dict['pred_labels'].long()]

                ret_dict[k]['pred_boxes'].append(final_dict['pred_boxes'])
                ret_dict[k]['pred_scores'].append(final_dict['pred_scores'])
                ret_dict[k]['pred_labels'].append(final_dict['pred_labels'])
                ret_dict[k]['pred_ious'].append(final_dict['pred_ious'])
                ret_dict[k]['voxel_ids'].append(final_dict['add_features'])

        for k in range(batch_size):
            pred_boxes = torch.cat(ret_dict[k]['pred_boxes'], dim=0)
            pred_scores = torch.cat(ret_dict[k]['pred_scores'], dim=0)
            pred_labels = torch.cat(ret_dict[k]['pred_labels'], dim=0)
            voxel_ids = torch.cat(ret_dict[k]['voxel_ids'], dim=0)

            ret_dict[k]['pred_boxes'] = pred_boxes
            ret_dict[k]['pred_scores'] = pred_scores
            # Shift labels so that the first class has id 1.
            ret_dict[k]['pred_labels'] = pred_labels + 1
            ret_dict[k]['voxel_ids'] = voxel_ids

        return ret_dict

    def _get_voxel_infos(self, x):
        """Extract index bookkeeping from the sparse tensor ``x``.

        Args:
            x: Object exposing ``spatial_shape``, ``indices`` and
                ``batch_size``; column 0 of ``indices`` is used as the batch
                index.

        Returns:
            Tuple ``(spatial_shape, batch_index, voxel_indices,
            spatial_indices, num_voxels)`` where ``spatial_indices`` holds,
            per sample, columns ``[2, 1]`` of that sample's index rows and
            ``num_voxels`` holds the per-sample row counts as 0-dim tensors.
        """
        spatial_shape = x.spatial_shape
        voxel_indices = x.indices
        batch_index = voxel_indices[:, 0]

        spatial_indices = []
        num_voxels = []
        for bs_idx in range(x.batch_size):
            batch_inds = batch_index == bs_idx
            spatial_indices.append(voxel_indices[batch_inds][:, [2, 1]])
            num_voxels.append(batch_inds.sum())

        return spatial_shape, batch_index, voxel_indices, spatial_indices, num_voxels

    def forward(self, data_dict):
        """Run all heads and decode their outputs.

        Args:
            data_dict: Dict providing ``'encoded_spconv_tensor'`` (the sparse
                input of the heads) and ``'batch_size'``.

        Returns:
            The list of per-sample prediction dicts produced by
            ``generate_predicted_boxes``.
        """
        x = data_dict['encoded_spconv_tensor']
        # Only the spatial shape and the raw indices are needed here.
        spatial_shape, _, voxel_indices, _, _ = self._get_voxel_infos(x)

        pred_dicts = [head(x) for head in self.heads_list]

        return self.generate_predicted_boxes(
            data_dict['batch_size'],
            pred_dicts, voxel_indices, spatial_shape
        )
|
|