File size: 7,275 Bytes
b793f0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import numpy as np
import torch
import torch.nn as nn
from voxelnext_3d_box.utils import centernet_utils
import spconv.pytorch as spconv
import copy
from spconv.core import ConvAlgo
class SeparateHead(nn.Module):
    """Independent sparse-conv branches, one per entry of ``sep_head_dict``.

    Each branch ``name`` with config ``{'out_channels': C, 'num_conv': n}`` is
    ``n - 1`` blocks of (``kernel_size`` SubMConv2d -> BatchNorm1d -> ReLU)
    followed by a 1x1 SubMConv2d with ``C`` output channels. Every branch
    reads the same input sparse tensor.

    Args:
        input_channels: Channel count of the input sparse tensor; hidden layers
            keep this width.
        sep_head_dict: Mapping of branch name to ``{'out_channels', 'num_conv'}``.
        kernel_size: Kernel size of the hidden SubMConv2d layers.
        use_bias: Whether the hidden SubMConv2d layers have a bias term (the
            final 1x1 layer always has one).
    """

    def __init__(self, input_channels, sep_head_dict, kernel_size, use_bias=False):
        super().__init__()
        self.sep_head_dict = sep_head_dict
        for cur_name in self.sep_head_dict:
            cur_cfg = self.sep_head_dict[cur_name]
            output_channels = cur_cfg['out_channels']
            num_conv = cur_cfg['num_conv']

            fc_list = []
            for _ in range(num_conv - 1):
                fc_list.append(spconv.SparseSequential(
                    # All hidden layers of one branch share indice_key=cur_name;
                    # presumably so spconv can reuse the submanifold index
                    # pairs between them (same kernel, same active sites).
                    spconv.SubMConv2d(input_channels, input_channels, kernel_size,
                                      padding=int(kernel_size // 2), bias=use_bias,
                                      indice_key=cur_name, algo=ConvAlgo.Native),
                    nn.BatchNorm1d(input_channels),
                    nn.ReLU(),
                ))
            # 1x1 output layer; different kernel size, so it gets its own key.
            fc_list.append(spconv.SubMConv2d(input_channels, output_channels, 1,
                                             bias=True, indice_key=cur_name + 'out',
                                             algo=ConvAlgo.Native))

            # Registered as a plain attribute named after the branch (not inside
            # an nn.ModuleDict) so state_dict keys stay ``<name>.<idx>...`` and
            # existing checkpoints keep loading.
            setattr(self, cur_name, nn.Sequential(*fc_list))

    def forward(self, x):
        """Run every branch on ``x`` and return ``{name: output.features}``.

        Keys follow the iteration order of ``sep_head_dict``. ``.features`` is
        the feature matrix of each branch's output sparse tensor — presumably
        shaped (num_active_voxels, out_channels); confirm against spconv.
        """
        return {
            cur_name: getattr(self, cur_name)(x).features
            for cur_name in self.sep_head_dict
        }
class VoxelNeXtHead(nn.Module):
    """Sparse voxel detection head: per-group SeparateHeads plus box decoding.

    One ``SeparateHead`` is built per entry of ``CLASS_NAMES_EACH_HEAD``. Each
    predicts the branches in ``SEPARATE_HEAD_CFG.HEAD_DICT`` plus a heatmap
    branch ``'hm'`` with one channel per class of that group. ``forward``
    decodes the predictions with
    ``centernet_utils.decode_bbox_from_voxels_nuscenes``.

    Args:
        class_names: Ordered list of all class names. Output labels are
            1-based indices into this list.
        point_cloud_range: Range values handed to the box decoder.
        voxel_size: Voxel size values handed to the box decoder.
        kernel_size_head: Kernel size of the hidden convs of each SeparateHead.
        CLASS_NAMES_EACH_HEAD: Iterable of class-name groups, one per head.
            Names not present in ``class_names`` are dropped. The filtered
            groups must together contain exactly ``len(class_names)`` names.
        SEPARATE_HEAD_CFG: Config with ``HEAD_DICT`` (branch configs, deep-copied
            per head) and ``HEAD_ORDER`` (``'vel'`` in it enables velocity).
        POST_PROCESSING: Config with ``POST_CENTER_LIMIT_RANGE``,
            ``MAX_OBJ_PER_SAMPLE`` and ``SCORE_THRESH``.
        input_channels: Channel count of the incoming sparse features
            (default 128, the value previously hard-coded).
        feature_map_stride: Feature-map stride handed to the box decoder
            (default 8, the value previously hard-coded).
    """

    def __init__(self, class_names, point_cloud_range, voxel_size, kernel_size_head,
                 CLASS_NAMES_EACH_HEAD, SEPARATE_HEAD_CFG, POST_PROCESSING,
                 input_channels=128, feature_map_stride=8):
        super().__init__()
        self.point_cloud_range = torch.Tensor(point_cloud_range)
        self.voxel_size = torch.Tensor(voxel_size)
        self.feature_map_stride = feature_map_stride
        self.class_names = class_names
        self.POST_PROCESSING = POST_PROCESSING

        self.class_names_each_head = []
        self.class_id_mapping_each_head = []
        for cur_class_names in CLASS_NAMES_EACH_HEAD:
            # Filter once; names and global ids are derived from the same list.
            kept_names = [x for x in cur_class_names if x in class_names]
            self.class_names_each_head.append(kept_names)
            # Explicit int64: np.array([]) would be float64 (not usable as an
            # index) and numpy's default int width is platform-dependent.
            self.class_id_mapping_each_head.append(torch.tensor(
                [self.class_names.index(x) for x in kept_names], dtype=torch.long
            ))

        total_classes = sum(len(x) for x in self.class_names_each_head)
        assert total_classes == len(self.class_names), f'class_names_each_head={self.class_names_each_head}'

        self.heads_list = nn.ModuleList()
        self.separate_head_cfg = SEPARATE_HEAD_CFG
        for cur_class_names in self.class_names_each_head:
            cur_head_dict = copy.deepcopy(self.separate_head_cfg.HEAD_DICT)
            # Heatmap branch: one output channel per class handled by this head.
            cur_head_dict['hm'] = dict(out_channels=len(cur_class_names), num_conv=2)
            self.heads_list.append(
                SeparateHead(
                    input_channels=input_channels,
                    sep_head_dict=cur_head_dict,
                    kernel_size=kernel_size_head,
                    use_bias=True,
                )
            )
        self.forward_ret_dict = {}

    def generate_predicted_boxes(self, batch_size, pred_dicts, voxel_indices, spatial_shape):
        """Decode per-head predictions into per-sample box lists.

        Args:
            batch_size: Number of samples; the result has this many entries.
            pred_dicts: One dict per head (as returned by ``SeparateHead``) with
                keys ``'hm'``, ``'center'``, ``'center_z'``, ``'dim'``, ``'rot'``
                (column 0 used as cos, column 1 as sin) and ``'vel'`` when
                ``'vel'`` is in ``HEAD_ORDER``.
            voxel_indices: Index tensor of the sparse voxels, passed to the
                decoder as ``indices``.
            spatial_shape: Unused; kept for interface compatibility.

        Returns:
            List of ``batch_size`` dicts. ``'pred_boxes'``, ``'pred_scores'``,
            ``'pred_labels'`` and ``'voxel_ids'`` are concatenated over heads.
            ``'pred_labels'`` are 1-based indices into ``class_names``.
            ``'voxel_ids'`` is the decoder's ``'add_features'`` output, fed from
            the row indices of ``voxel_indices``. ``'pred_ious'`` stays a
            per-head list of whatever the decoder returned.
        """
        device = pred_dicts[0]['hm'].device
        post_process_cfg = self.POST_PROCESSING
        post_center_limit_range = torch.tensor(
            post_process_cfg.POST_CENTER_LIMIT_RANGE, dtype=torch.float32, device=device
        )
        # Same for every head: move once instead of once per head.
        point_cloud_range = self.point_cloud_range.to(device)
        voxel_size = self.voxel_size.to(device)
        use_vel = 'vel' in self.separate_head_cfg.HEAD_ORDER

        ret_dict = [{
            'pred_boxes': [],
            'pred_scores': [],
            'pred_labels': [],
            'pred_ious': [],
            'voxel_ids': [],
        } for _ in range(batch_size)]

        for idx, pred_dict in enumerate(pred_dicts):
            final_pred_dicts = centernet_utils.decode_bbox_from_voxels_nuscenes(
                batch_size=batch_size, indices=voxel_indices,
                obj=pred_dict['hm'].sigmoid(),
                rot_cos=pred_dict['rot'][:, 0].unsqueeze(dim=1),
                rot_sin=pred_dict['rot'][:, 1].unsqueeze(dim=1),
                center=pred_dict['center'], center_z=pred_dict['center_z'],
                dim=pred_dict['dim'].exp(),
                vel=pred_dict['vel'] if use_vel else None,
                iou=None,
                point_cloud_range=point_cloud_range, voxel_size=voxel_size,
                feature_map_stride=self.feature_map_stride,
                K=post_process_cfg.MAX_OBJ_PER_SAMPLE,
                score_thresh=post_process_cfg.SCORE_THRESH,
                post_center_limit_range=post_center_limit_range,
                # Row index of every voxel; read back below as 'voxel_ids'.
                add_features=torch.arange(
                    voxel_indices.shape[0], device=voxel_indices.device
                ).unsqueeze(-1),
            )

            # Head-local class ids -> 0-based indices into self.class_names.
            class_id_mapping = self.class_id_mapping_each_head[idx].to(device)
            for k, final_dict in enumerate(final_pred_dicts):
                final_dict['pred_labels'] = class_id_mapping[final_dict['pred_labels'].long()]
                ret_dict[k]['pred_boxes'].append(final_dict['pred_boxes'])
                ret_dict[k]['pred_scores'].append(final_dict['pred_scores'])
                ret_dict[k]['pred_labels'].append(final_dict['pred_labels'])
                ret_dict[k]['pred_ious'].append(final_dict['pred_ious'])
                ret_dict[k]['voxel_ids'].append(final_dict['add_features'])

        for cur_ret in ret_dict:
            cur_ret['pred_boxes'] = torch.cat(cur_ret['pred_boxes'], dim=0)
            cur_ret['pred_scores'] = torch.cat(cur_ret['pred_scores'], dim=0)
            # Shift to 1-based labels.
            cur_ret['pred_labels'] = torch.cat(cur_ret['pred_labels'], dim=0) + 1
            cur_ret['voxel_ids'] = torch.cat(cur_ret['voxel_ids'], dim=0)
        return ret_dict

    def _get_voxel_infos(self, x):
        """Split the voxel indices of sparse tensor ``x`` by batch sample.

        Returns:
            ``(spatial_shape, batch_index, voxel_indices, spatial_indices,
            num_voxels)``:
            - ``batch_index`` is column 0 of ``x.indices``.
            - ``spatial_indices[b]`` holds columns ``[2, 1]`` (in that order) of
              the rows belonging to sample ``b``.
            - ``num_voxels[b]`` is that sample's row count, as a 0-d tensor.
        """
        spatial_shape = x.spatial_shape
        voxel_indices = x.indices
        batch_index = voxel_indices[:, 0]
        spatial_indices = []
        num_voxels = []
        for bs_idx in range(x.batch_size):
            batch_inds = batch_index == bs_idx
            spatial_indices.append(voxel_indices[batch_inds][:, [2, 1]])
            num_voxels.append(batch_inds.sum())
        return spatial_shape, batch_index, voxel_indices, spatial_indices, num_voxels

    def forward(self, data_dict):
        """Run all heads on ``data_dict['encoded_spconv_tensor']`` and decode boxes.

        Returns the per-sample list produced by ``generate_predicted_boxes``.
        """
        x = data_dict['encoded_spconv_tensor']
        pred_dicts = [head(x) for head in self.heads_list]
        # Decoding only needs the raw indices and spatial shape. Read them
        # directly rather than via _get_voxel_infos. That helper's per-sample
        # boolean-mask indexing has a data-dependent output size, and its
        # results were discarded here anyway.
        return self.generate_predicted_boxes(
            data_dict['batch_size'], pred_dicts, x.indices, x.spatial_shape
        )
|