# Scraped Hugging Face file-view header (not part of the program):
# uploader "mart9992", commit message "m", commit b793f0c, file size 6.39 kB.
import numpy as np
import torch
import torch.nn as nn
from .models.data_processor import DataProcessor
from .models.mean_vfe import MeanVFE
from .models.spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt
from .models.voxelnext_head import VoxelNeXtHead
from .utils.image_projection import _proj_voxel_image
from segment_anything import SamPredictor, sam_model_registry
class VoxelNeXt(nn.Module):
    """Container for the VoxelNeXt inference pipeline stages.

    The constructor only wires the stages together from ``model_cfg``; no
    ``forward`` is defined here, so callers invoke the stages
    (``data_processor``, ``voxelization``, ``backbone_3d``, ``dense_head``)
    one by one.

    Args:
        model_cfg: Config object supporting attribute access for
            ``POINT_CLOUD_RANGE``, ``DATA_PROCESSOR`` and
            ``USED_FEATURE_LIST`` and ``.get(key, default)`` for the
            remaining, optional entries.
    """

    def __init__(self, model_cfg):
        super().__init__()

        # Point-cloud pre-processing, always built in inference mode.
        self.data_processor = DataProcessor(
            model_cfg.DATA_PROCESSOR,
            point_cloud_range=np.array(model_cfg.POINT_CLOUD_RANGE, dtype=np.float32),
            training=False,
            num_point_features=len(model_cfg.USED_FEATURE_LIST),
        )

        # Plain tensor attributes (not registered buffers), so ``.to(device)``
        # on the module does not move them; users move them explicitly.
        self.point_cloud_range = torch.Tensor(
            model_cfg.get('POINT_CLOUD_RANGE', [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0])
        )
        self.voxel_size = torch.Tensor(model_cfg.get('VOXEL_SIZE', [0.075, 0.075, 0.2]))

        self.voxelization = MeanVFE()
        self.backbone_3d = VoxelResBackBone8xVoxelNeXt(
            model_cfg.get('INPUT_CHANNELS', 5),
            np.array(model_cfg.get('GRID_SIZE', [1440, 1440, 40])),
        )
        self.dense_head = VoxelNeXtHead(
            model_cfg.get('CLASS_NAMES'),
            self.point_cloud_range,
            self.voxel_size,
            model_cfg.get('KERNEL_SIZE_HEAD', 1),
            model_cfg.get('CLASS_NAMES_EACH_HEAD'),
            model_cfg.get('SEPARATE_HEAD_CFG'),
            model_cfg.get('POST_PROCESSING'),
        )
class Model(nn.Module):
    """Promptable 2D mask + 3D box predictor built from SAM and VoxelNeXt.

    A point prompt on the image yields a SAM mask; VoxelNeXt predictions are
    projected into the image and the box whose voxel falls inside the mask
    (or, failing that, lies closest to the mask centre) is returned.

    Args:
        model_cfg: Config object supporting ``.get(key, default)``; read keys
            are ``SAM_TYPE``, ``SAM_CHECKPOINT`` and ``VOXELNEXT_CHECKPOINT``
            plus everything ``VoxelNeXt`` reads.
        device: Device both networks are moved to.

    NOTE(review): the module is left in ``nn.Module``'s default training
    mode; callers presumably call ``.eval()`` before inference -- confirm.
    """

    def __init__(self, model_cfg, device="cuda"):
        super().__init__()

        sam_type = model_cfg.get('SAM_TYPE', "vit_b")
        sam_checkpoint = model_cfg.get('SAM_CHECKPOINT', "/data/sam_vit_b_01ec64.pth")
        sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint).to(device=device)
        self.sam_predictor = SamPredictor(sam)

        voxelnext_checkpoint = model_cfg.get('VOXELNEXT_CHECKPOINT', "/data/voxelnext_nuscenes_kernel1.pth")
        # Deserialize onto the CPU so a checkpoint saved from a GPU still
        # loads on hosts without that GPU; load_state_dict then copies the
        # weights onto whatever device the module lives on.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        model_dict = torch.load(voxelnext_checkpoint, map_location="cpu")
        self.voxelnext = VoxelNeXt(model_cfg).to(device=device)
        self.voxelnext.load_state_dict(model_dict)

        # Backbone outputs keyed by image_id. Entries are never evicted here,
        # so long-running callers may want to clear this dict themselves.
        self.point_features = {}
        self.device = device

    def image_embedding(self, image):
        """Hand ``image`` to the SAM predictor via ``set_image``."""
        self.sam_predictor.set_image(image)

    def point_embedding(self, data_dict, image_id):
        """Run VoxelNeXt on a point cloud, caching backbone output per image.

        Args:
            data_dict: Raw input dict for ``data_processor.forward``. Ignored
                when ``image_id`` is already cached.
            image_id: Hashable cache key for the backbone output.

        Returns:
            Tuple ``(pred_dicts, voxel_coords)``: the dense-head output and
            the ``out_voxels`` rows selected by ``pred_dicts[0]['voxel_ids']``
            scaled by the head's ``feature_map_stride``.
        """
        if image_id in self.point_features:
            # Cache hit: skip preprocessing, voxelisation and the backbone.
            data_dict = self.point_features[image_id]
        else:
            data_dict = self.voxelnext.data_processor.forward(data_dict=data_dict)
            for key in ('voxels', 'voxel_num_points', 'voxel_coords'):
                data_dict[key] = torch.Tensor(data_dict[key]).to(self.device)
            data_dict = self.voxelnext.voxelization(data_dict)

            # Prepend an all-zero column (single-sample batch index).
            coords = data_dict['voxel_coords']
            batch_column = torch.zeros((coords.shape[0], 1), device=coords.device, dtype=coords.dtype)
            data_dict['voxel_coords'] = torch.cat([batch_column, coords], dim=1)
            data_dict['batch_size'] = 1

            data_dict = self.voxelnext.backbone_3d(data_dict)
            self.point_features[image_id] = data_dict

        pred_dicts = self.voxelnext.dense_head(data_dict)
        voxel_ids = pred_dicts[0]['voxel_ids'].squeeze(-1)
        voxel_coords = data_dict['out_voxels'][voxel_ids] * self.voxelnext.dense_head.feature_map_stride
        return pred_dicts, voxel_coords

    @staticmethod
    def _mask_to_numpy(mask):
        """Return ``mask`` (numpy array or torch tensor) as a boolean numpy array."""
        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        return np.asarray(mask).astype(bool)

    @staticmethod
    def _voxels_inside_mask(points_image, mask):
        """Flag projected voxels whose (truncated) pixel lies on a set mask pixel.

        Args:
            points_image: Tensor indexed as ``[xy, voxel]`` (row 0 = x/column,
                row 1 = y/row), as consumed by ``generate_3D_box``.
            mask: 2D boolean numpy array indexed ``[row, column]``.

        Returns:
            1D bool tensor on ``points_image.device``, one flag per voxel;
            voxels projecting outside the mask bounds are ``False``.
        """
        pixels = points_image.permute(1, 0).int().cpu().numpy()
        xs, ys = pixels[:, 0], pixels[:, 1]
        height, width = mask.shape[0], mask.shape[1]
        in_view = (xs >= 0) & (ys >= 0) & (xs < width) & (ys < height)
        hits = np.zeros(pixels.shape[0], dtype=bool)
        hits[in_view] = mask[ys[in_view], xs[in_view]]
        return torch.from_numpy(hits).to(points_image.device)

    def generate_3D_box(self, lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=0.1):
        """Pick the predicted 3D box that best matches a 2D mask.

        Args:
            lidar2img_rt: Projection argument forwarded to ``_proj_voxel_image``.
            mask: 2D mask indexed ``[row, column]`` (numpy array or tensor).
            voxel_coords: Voxel coordinates of the predictions.
            pred_dicts: Dense-head output; ``pred_dicts[0]`` must provide
                ``pred_scores`` and ``pred_boxes``.
            quality_score: Predictions scoring at or below this are ignored.

        Returns:
            The highest-scoring qualifying box whose voxel projects onto the
            mask; otherwise the qualifying box whose projection is closest to
            the mask centre. ``None`` when no prediction exceeds
            ``quality_score`` or when the mask has no set pixel.
        """
        device = voxel_coords.device
        points_image, depth = _proj_voxel_image(
            voxel_coords,
            lidar2img_rt,
            self.voxelnext.voxel_size.to(device),
            self.voxelnext.point_cloud_range.to(device),
        )

        scores = pred_dicts[0]['pred_scores']
        boxes = pred_dicts[0]['pred_boxes']

        mask_extra = scores > quality_score
        if not mask_extra.any():
            print("no high quality 3D box related.")
            return None

        mask_np = self._mask_to_numpy(mask)
        selected = self._voxels_inside_mask(points_image, mask_np) & mask_extra
        if selected.any():
            best = scores[selected].argmax()
            return boxes[selected][best]

        # No qualifying voxel lands on the mask: fall back to the qualifying
        # projection nearest to the mask's centre of mass.
        rows, cols = np.nonzero(mask_np)
        if rows.size == 0:
            # An empty mask has no centre; the distances would all be NaN and
            # the chosen box arbitrary, so report "no box" instead.
            return None
        mask_center = torch.Tensor([float(cols.mean()), float(rows.mean())]).to(boxes.device).unsqueeze(1)
        dist = ((points_image - mask_center) ** 2).sum(0)
        nearest = dist[mask_extra].argmin()
        return boxes[mask_extra][nearest]

    def forward(self, image, point_dict, prompt_point, lidar2img_rt, image_id, quality_score=0.1):
        """Segment the prompted object and return its mask and 3D box.

        Args:
            image: Image passed to the SAM predictor.
            point_dict: Raw point-cloud dict for ``point_embedding``.
            prompt_point: Prompt coordinates passed to SAM as ``point_coords``;
                every prompt is treated as a foreground point (label 1).
            lidar2img_rt: Projection argument for ``generate_3D_box``.
            image_id: Cache key for the point-cloud features.
            quality_score: Score threshold forwarded to ``generate_3D_box``.

        Returns:
            Tuple ``(mask, box3d)`` where ``mask`` is the first mask returned
            by SAM and ``box3d`` may be ``None``.
        """
        self.image_embedding(image)
        pred_dicts, voxel_coords = self.point_embedding(point_dict, image_id)

        # One foreground label per prompt point (a single point gives the
        # original ``np.array([1])``).
        point_labels = np.ones(len(prompt_point), dtype=int)
        masks, _, _ = self.sam_predictor.predict(point_coords=prompt_point, point_labels=point_labels)
        mask = masks[0]

        box3d = self.generate_3D_box(lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=quality_score)
        return mask, box3d
if __name__ == '__main__':
    # Smoke test: build the model and run the point-cloud branch on one sample.
    # NOTE(review): cfg_from_yaml_file, cfg and NuScenesDataset are not
    # imported in the visible part of this file -- confirm where they come
    # from before running this script.
    cfg_dataset = 'nuscenes_dataset.yaml'
    cfg_model = 'config.yaml'

    dataset_cfg = cfg_from_yaml_file(cfg_dataset, cfg)
    model_cfg = cfg_from_yaml_file(cfg_model, cfg)

    nuscenes_dataset = NuScenesDataset(dataset_cfg)
    model = Model(model_cfg)

    index = 0
    data_dict = nuscenes_dataset._get_points(index)
    # point_embedding requires an image_id cache key; the sample index is
    # used here so the call no longer raises TypeError.
    model.point_embedding(data_dict, image_id=index)