# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet3d.models.detectors import Base3DDetector
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import ConfigType, OptConfigType

from .nerf_utils.nerf_mlp import VanillaNeRF
from .nerf_utils.projection import Projector
from .nerf_utils.render_ray import render_rays

# from ..utils.nerf_utils.save_rendered_img import save_rendered_img


@MODELS.register_module()
class NerfDet(Base3DDetector):
    r"""NeRF-Det: multi-view indoor 3D object detection with a NeRF branch.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        prior_generator (:obj:`ConfigDict` or dict): The prior generator
            config.
        n_voxels (list): Number of voxels along x, y, z axis.
        voxel_size (list): The size of voxels. Each voxel represents a cube
            of ``voxel_size[0]`` x ``voxel_size[1]`` x ``voxel_size[2]``
            meters.
        head_2d (:obj:`ConfigDict` or dict, optional): The 2D head config,
            whose predicted angles are used to build projections at test
            time. Defaults to None.
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
        render_testing (bool): Whether to render novel views during testing.
            Set ``render_testing = True`` in the config to enable it.
            Defaults to False.

    The remaining arguments are NeRF-specific hyper-parameters; the default
    values are suitable for common setups.
    """

    def __init__(
            self,
            backbone: ConfigType,
            neck: ConfigType,
            neck_3d: ConfigType,
            bbox_head: ConfigType,
            prior_generator: ConfigType,
            n_voxels: List,
            voxel_size: List,
            head_2d: ConfigType = None,
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            data_preprocessor: OptConfigType = None,
            init_cfg: OptConfigType = None,
            # pretrained,
            aabb: Tuple = None,
            near_far_range: List = None,
            N_samples: int = 64,
            N_rand: int = 2048,
            depth_supervise: bool = False,
            use_nerf_mask: bool = True,
            nerf_sample_view: int = 3,
            nerf_mode: str = 'volume',
            squeeze_scale: int = 4,
            rgb_supervision: bool = True,
            nerf_density: bool = False,
            render_testing: bool = False):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.neck_3d = MODELS.build(neck_3d)
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = MODELS.build(bbox_head)
        self.head_2d = MODELS.build(head_2d) if head_2d is not None else None
        self.n_voxels = n_voxels
        self.prior_generator = TASK_UTILS.build(prior_generator)
        self.voxel_size = voxel_size
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.aabb = aabb
        self.near_far_range = near_far_range
        self.N_samples = N_samples
        self.N_rand = N_rand
        self.depth_supervise = depth_supervise
        self.projector = Projector()
        self.squeeze_scale = squeeze_scale
        self.use_nerf_mask = use_nerf_mask
        self.rgb_supervision = rgb_supervision
        nerf_feature_dim = neck['out_channels'] // squeeze_scale
        self.nerf_mlp = VanillaNeRF(
            net_depth=4,  # The depth of the MLP.
            net_width=256,  # The width of the MLP.
            skip_layer=3,  # The layer to add skip layers to.
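            # feature_dim is the size of the per-point conditioning feature:
            # in the density-query path of ``extract_feat`` it is the mean
            # and variance of the back-projected volume, each carrying
            # nerf_feature_dim // 2 feature channels plus 3 RGB channels,
            # hence nerf_feature_dim + 6 in total.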
            feature_dim=nerf_feature_dim + 6,  # + RGB original imgs
            net_depth_condition=1,  # The depth of the second part of MLP.
            net_width_condition=128)
        self.nerf_mode = nerf_mode
        self.nerf_density = nerf_density
        self.nerf_sample_view = nerf_sample_view
        self.render_testing = render_testing

        # hard code here, will deal with batch issue later.
        self.cov = nn.Sequential(
            nn.Conv3d(
                neck['out_channels'],
                neck['out_channels'],
                kernel_size=3,
                padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                neck['out_channels'],
                neck['out_channels'],
                kernel_size=3,
                padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(neck['out_channels'], 1, kernel_size=1))

        self.mean_mapping = nn.Sequential(
            nn.Conv3d(
                neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))
        self.cov_mapping = nn.Sequential(
            nn.Conv3d(
                neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))
        self.mapping = nn.Sequential(
            nn.Linear(neck['out_channels'], nerf_feature_dim // 2))
        self.mapping_2d = nn.Sequential(
            nn.Conv2d(
                neck['out_channels'], nerf_feature_dim // 2, kernel_size=1))

        # self.overfit_nerfmlp = overfit_nerfmlp
        # if self.overfit_nerfmlp:
        #     self._finetuning_NeRF_MLP()
        self.render_testing = render_testing

    def extract_feat(self,
                     batch_inputs_dict: dict,
                     batch_data_samples: SampleList,
                     mode,
                     depth=None,
                     ray_batch=None):
        """Extract 3D features: backbone -> FPN -> 3D projection -> 3D neck.

        Args:
            batch_inputs_dict (dict): The model input dict which includes
                the 'imgs' key.

                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                `gt_instances`, `gt_panoptic_seg` or `gt_sem_seg`.
            mode (str): 'train' or 'test'. In 'test' mode the angles
                predicted by ``head_2d`` (if present) are used to build the
                projection matrices.
            depth (torch.Tensor, optional): Per-view depth maps used to
                restrict the back-projection. Defaults to None.
            ray_batch (dict, optional): Ray information for NeRF rendering.
                If None, rendering is skipped. Defaults to None.

        Returns:
            Tuple:
            - torch.Tensor: Features of shape (N, C_out, N_x, N_y, N_z).
            - torch.Tensor: Valid mask of shape (N, 1, N_x, N_y, N_z).
            - torch.Tensor: 2D features if needed.
            - dict: The NeRF rendered information, including the
              'outputs_coarse', 'gt_rgb' and 'gt_depth' keys.
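
        Note:
            Each element of the returned ``rgb_preds`` list is the dict
            produced by ``render_rays``. The loss functions below read
            ``ret['outputs_coarse']['rgb']``,
            ``ret['outputs_coarse']['depth']`` and
            ``ret['outputs_coarse']['mask']`` together with
            ``ret['gt_rgb']`` and ``ret['gt_depth']`` from it.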
""" img = batch_inputs_dict['imgs'] img = img.float() batch_img_metas = [ data_samples.metainfo for data_samples in batch_data_samples ] batch_size = img.shape[0] if len(img.shape) > 4: img = img.reshape([-1] + list(img.shape)[2:]) x = self.backbone(img) x = self.neck(x)[0] x = x.reshape([batch_size, -1] + list(x.shape[1:])) else: x = self.backbone(img) x = self.neck(x)[0] if depth is not None: depth_bs = depth.shape[0] assert depth_bs == batch_size depth = batch_inputs_dict['depth'] depth = depth.reshape([-1] + list(depth.shape)[2:]) features_2d = self.head_2d.forward(x[-1], batch_img_metas) \ if self.head_2d is not None else None stride = img.shape[-1] / x.shape[-1] assert stride == 4 stride = int(stride) volumes, valids = [], [] rgb_preds = [] for feature, img_meta in zip(x, batch_img_metas): angles = features_2d[ 0] if features_2d is not None and mode == 'test' else None projection = self._compute_projection(img_meta, stride, angles).to(x.device) points = get_points( n_voxels=torch.tensor(self.n_voxels), voxel_size=torch.tensor(self.voxel_size), origin=torch.tensor(img_meta['lidar2img']['origin'])).to( x.device) height = img_meta['img_shape'][0] // stride width = img_meta['img_shape'][1] // stride # Construct the volume space # volume together with valid is the constructed scene # volume represents V_i and valid represents M_p volume, valid = backproject(feature[:, :, :height, :width], points, projection, depth, self.voxel_size) density = None volume_sum = volume.sum(dim=0) # cov_valid = valid.clone().detach() valid = valid.sum(dim=0) volume_mean = volume_sum / (valid + 1e-8) volume_mean[:, valid[0] == 0] = .0 # volume_cov = (volume - volume_mean.unsqueeze(0)) ** 2 * cov_valid # volume_cov = torch.sum(volume_cov, dim=0) / (valid + 1e-8) volume_cov = torch.sum( (volume - volume_mean.unsqueeze(0))**2, dim=0) / ( valid + 1e-8) volume_cov[:, valid[0] == 0] = 1e6 volume_cov = torch.exp(-volume_cov) # default setting # be careful here, the smaller the cov, the larger the weight. n_channels, n_x_voxels, n_y_voxels, n_z_voxels = volume_mean.shape if ray_batch is not None: if self.nerf_mode == 'volume': mean_volume = self.mean_mapping(volume_mean.unsqueeze(0)) cov_volume = self.cov_mapping(volume_cov.unsqueeze(0)) feature_2d = feature[:, :, :height, :width] elif self.nerf_mode == 'image': mean_volume = None cov_volume = None feature_2d = feature[:, :, :height, :width] n_v, C, height, width = feature_2d.shape feature_2d = feature_2d.view(n_v, C, -1).permute(0, 2, 1).contiguous() feature_2d = self.mapping(feature_2d).permute( 0, 2, 1).contiguous().view(n_v, -1, height, width) denorm_images = ray_batch['denorm_images'] denorm_images = denorm_images.reshape( [-1] + list(denorm_images.shape)[2:]) rgb_projection = self._compute_projection( img_meta, stride=1, angles=None).to(x.device) rgb_volume, _ = backproject( denorm_images[:, :, :img_meta['img_shape'][0], : img_meta['img_shape'][1]], points, rgb_projection, depth, self.voxel_size) ret = render_rays( ray_batch, mean_volume, cov_volume, feature_2d, denorm_images, self.aabb, self.near_far_range, self.N_samples, self.N_rand, self.nerf_mlp, img_meta, self.projector, self.nerf_mode, self.nerf_sample_view, is_train=True if mode == 'train' else False, render_testing=self.render_testing) rgb_preds.append(ret) if self.nerf_density: # would have 0 bias issue for mean_mapping. 
                    n_v, C, n_x_voxels, n_y_voxels, n_z_voxels = volume.shape
                    volume = volume.view(n_v, C,
                                         -1).permute(0, 2, 1).contiguous()
                    mapping_volume = self.mapping(volume).permute(
                        0, 2, 1).contiguous().view(n_v, -1, n_x_voxels,
                                                   n_y_voxels, n_z_voxels)

                    mapping_volume = torch.cat([rgb_volume, mapping_volume],
                                               dim=1)
                    mapping_volume_sum = mapping_volume.sum(dim=0)
                    mapping_volume_mean = mapping_volume_sum / (valid + 1e-8)
                    # mapping_volume_cov = (
                    #     mapping_volume - mapping_volume_mean.unsqueeze(0)
                    # ) ** 2 * cov_valid
                    mapping_volume_cov = (
                        mapping_volume - mapping_volume_mean.unsqueeze(0))**2
                    mapping_volume_cov = torch.sum(
                        mapping_volume_cov, dim=0) / (
                            valid + 1e-8)
                    mapping_volume_cov[:, valid[0] == 0] = 1e6
                    mapping_volume_cov = torch.exp(
                        -mapping_volume_cov)  # default setting
                    global_volume = torch.cat(
                        [mapping_volume_mean, mapping_volume_cov], dim=1)
                    global_volume = global_volume.view(
                        -1, n_x_voxels * n_y_voxels * n_z_voxels).permute(
                            1, 0).contiguous()
                    points = points.view(3, -1).permute(1, 0).contiguous()
                    density = self.nerf_mlp.query_density(
                        points, global_volume)
                    alpha = 1 - torch.exp(-density)  # density -> alpha
                    # (1, n_x_voxels, n_y_voxels, n_z_voxels)
                    volume = alpha.view(1, n_x_voxels, n_y_voxels,
                                        n_z_voxels) * volume_mean
                    volume[:, valid[0] == 0] = .0

            volumes.append(volume)
            valids.append(valid)
        x = torch.stack(volumes)
        x = self.neck_3d(x)
        return x, torch.stack(valids).float(), features_2d, rgb_preds

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which includes
                the 'imgs' key.

                - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such as
                `gt_instances`, `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
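
        Note:
            In addition to the detection losses from ``bbox_head``, a
            novel-view synthesis loss ``loss_nvs`` is added when ray
            information is present and ``rgb_supervision`` is True, and a
            ``loss_depth`` term is added when ``depth_supervise`` is True.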
""" ray_batchs = {} batch_images = [] batch_depths = [] if 'images' in batch_data_samples[0].gt_nerf_images: for data_samples in batch_data_samples: image = data_samples.gt_nerf_images['images'] batch_images.append(image) batch_images = torch.stack(batch_images) if 'depths' in batch_data_samples[0].gt_nerf_depths: for data_samples in batch_data_samples: depth = data_samples.gt_nerf_depths['depths'] batch_depths.append(depth) batch_depths = torch.stack(batch_depths) if 'raydirs' in batch_inputs_dict.keys(): ray_batchs['ray_o'] = batch_inputs_dict['lightpos'] ray_batchs['ray_d'] = batch_inputs_dict['raydirs'] ray_batchs['gt_rgb'] = batch_images ray_batchs['gt_depth'] = batch_depths ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes'] ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images'] x, valids, features_2d, rgb_preds = self.extract_feat( batch_inputs_dict, batch_data_samples, 'train', depth=None, ray_batch=ray_batchs) else: x, valids, features_2d, rgb_preds = self.extract_feat( batch_inputs_dict, batch_data_samples, 'train') x += (valids, ) losses = self.bbox_head.loss(x, batch_data_samples, **kwargs) # if self.head_2d is not None: # losses.update( # self.head_2d.loss(*features_2d, batch_data_samples) # ) if len(ray_batchs) != 0 and self.rgb_supervision: losses.update(self.nvs_loss_func(rgb_preds)) if self.depth_supervise: losses.update(self.depth_loss_func(rgb_preds)) return losses def nvs_loss_func(self, rgb_pred): loss = 0 for ret in rgb_pred: rgb = ret['outputs_coarse']['rgb'] gt = ret['gt_rgb'] masks = ret['outputs_coarse']['mask'] if self.use_nerf_mask: loss += torch.sum(masks.unsqueeze(-1) * (rgb - gt)**2) / ( masks.sum() + 1e-6) else: loss += torch.mean((rgb - gt)**2) return dict(loss_nvs=loss) def depth_loss_func(self, rgb_pred): loss = 0 for ret in rgb_pred: depth = ret['outputs_coarse']['depth'] gt = ret['gt_depth'].squeeze(-1) masks = ret['outputs_coarse']['mask'] if self.use_nerf_mask: loss += torch.sum(masks * torch.abs(depth - gt)) / ( masks.sum() + 1e-6) else: loss += torch.mean(torch.abs(depth - gt)) return dict(loss_depth=loss) def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList, **kwargs) -> SampleList: """Predict results from a batch of inputs and data samples with post- processing. Args: batch_inputs_dict (dict): The model input dict which include the 'imgs' key. - imgs (torch.Tensor, optional): Image of each sample. batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data Samples. It usually includes information such as `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. Returns: list[:obj:`NeRFDet3DDataSample`]: Detection results of the input images. Each NeRFDet3DDataSample usually contain 'pred_instances_3d'. And the ``pred_instances_3d`` usually contains following keys. - scores_3d (Tensor): Classification scores, has a shape (num_instance, ) - labels_3d (Tensor): Labels of bboxes, has a shape (num_instances, ). - bboxes_3d (Tensor): Contains a tensor with shape (num_instances, C) where C = 6. 
""" ray_batchs = {} batch_images = [] batch_depths = [] if 'images' in batch_data_samples[0].gt_nerf_images: for data_samples in batch_data_samples: image = data_samples.gt_nerf_images['images'] batch_images.append(image) batch_images = torch.stack(batch_images) if 'depths' in batch_data_samples[0].gt_nerf_depths: for data_samples in batch_data_samples: depth = data_samples.gt_nerf_depths['depths'] batch_depths.append(depth) batch_depths = torch.stack(batch_depths) if 'raydirs' in batch_inputs_dict.keys(): ray_batchs['ray_o'] = batch_inputs_dict['lightpos'] ray_batchs['ray_d'] = batch_inputs_dict['raydirs'] ray_batchs['gt_rgb'] = batch_images ray_batchs['gt_depth'] = batch_depths ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes'] ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images'] x, valids, features_2d, rgb_preds = self.extract_feat( batch_inputs_dict, batch_data_samples, 'test', depth=None, ray_batch=ray_batchs) else: x, valids, features_2d, rgb_preds = self.extract_feat( batch_inputs_dict, batch_data_samples, 'test') x += (valids, ) results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs) predictions = self.add_pred_to_datasample(batch_data_samples, results_list) return predictions def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList, *args, **kwargs) -> Tuple[List[torch.Tensor]]: """Network forward process. Usually includes backbone, neck and head forward without any post-processing. Args: batch_inputs_dict (dict): The model input dict which include the 'imgs' key. - imgs (torch.Tensor, optional): Image of each sample. batch_data_samples (List[:obj:`NeRFDet3DDataSample`]): The Data Samples. It usually includes information such as `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d` Returns: tuple[list]: A tuple of features from ``bbox_head`` forward """ ray_batchs = {} batch_images = [] batch_depths = [] if 'images' in batch_data_samples[0].gt_nerf_images: for data_samples in batch_data_samples: image = data_samples.gt_nerf_images['images'] batch_images.append(image) batch_images = torch.stack(batch_images) if 'depths' in batch_data_samples[0].gt_nerf_depths: for data_samples in batch_data_samples: depth = data_samples.gt_nerf_depths['depths'] batch_depths.append(depth) batch_depths = torch.stack(batch_depths) if 'raydirs' in batch_inputs_dict.keys(): ray_batchs['ray_o'] = batch_inputs_dict['lightpos'] ray_batchs['ray_d'] = batch_inputs_dict['raydirs'] ray_batchs['gt_rgb'] = batch_images ray_batchs['gt_depth'] = batch_depths ray_batchs['nerf_sizes'] = batch_inputs_dict['nerf_sizes'] ray_batchs['denorm_images'] = batch_inputs_dict['denorm_images'] x, valids, features_2d, rgb_preds = self.extract_feat( batch_inputs_dict, batch_data_samples, 'train', depth=None, ray_batch=ray_batchs) else: x, valids, features_2d, rgb_preds = self.extract_feat( batch_inputs_dict, batch_data_samples, 'train') x += (valids, ) results = self.bbox_head.forward(x) return results def aug_test(self, batch_inputs_dict, batch_data_samples): pass def show_results(self, *args, **kwargs): pass @staticmethod def _compute_projection(img_meta, stride, angles): projection = [] intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:3, :3]) ratio = img_meta['ori_shape'][0] / (img_meta['img_shape'][0] / stride) intrinsic[:2] /= ratio # use predict pitch and roll for SUNRGBDTotal test if angles is not None: extrinsics = [] for angle in angles: extrinsics.append(get_extrinsics(angle).to(intrinsic.device)) else: extrinsics = map(torch.tensor, 
                             img_meta['lidar2img']['extrinsic'])
        for extrinsic in extrinsics:
            projection.append(intrinsic @ extrinsic[:3])
        return torch.stack(projection)


@torch.no_grad()
def get_points(n_voxels, voxel_size, origin):
    # origin: point-cloud center.
    points = torch.stack(
        torch.meshgrid([
            torch.arange(n_voxels[0]),  # 40 W width, x
            torch.arange(n_voxels[1]),  # 40 D depth, y
            torch.arange(n_voxels[2])  # 16 H height, z
        ]))
    new_origin = origin - n_voxels / 2. * voxel_size
    points = points * voxel_size.view(3, 1, 1, 1) + new_origin.view(
        3, 1, 1, 1)
    return points


# modify from https://github.com/magicleap/Atlas/blob/master/atlas/model.py
def backproject(features, points, projection, depth, voxel_size):
    n_images, n_channels, height, width = features.shape
    n_x_voxels, n_y_voxels, n_z_voxels = points.shape[-3:]
    points = points.view(1, 3, -1).expand(n_images, 3, -1)
    points = torch.cat((points, torch.ones_like(points[:, :1])), dim=1)
    points_2d_3 = torch.bmm(projection, points)
    x = (points_2d_3[:, 0] / points_2d_3[:, 2]).round().long()
    y = (points_2d_3[:, 1] / points_2d_3[:, 2]).round().long()
    z = points_2d_3[:, 2]
    valid = (x >= 0) & (y >= 0) & (x < width) & (y < height) & (z > 0)

    # below is using depth to sample feature
    if depth is not None:
        depth = F.interpolate(
            depth.unsqueeze(1), size=(height, width),
            mode='bilinear').squeeze(1)
        for i in range(n_images):
            z_mask = z.clone() > 0
            z_mask[i, valid[i]] = \
                (z[i, valid[i]] > depth[i, y[i, valid[i]], x[i, valid[i]]] - voxel_size[-1]) & \
                (z[i, valid[i]] < depth[i, y[i, valid[i]], x[i, valid[i]]] + voxel_size[-1])  # noqa
            valid = valid & z_mask

    volume = torch.zeros((n_images, n_channels, points.shape[-1]),
                         device=features.device)
    for i in range(n_images):
        volume[i, :, valid[i]] = features[i, :, y[i, valid[i]], x[i, valid[i]]]
    volume = volume.view(n_images, n_channels, n_x_voxels, n_y_voxels,
                         n_z_voxels)
    valid = valid.view(n_images, 1, n_x_voxels, n_y_voxels, n_z_voxels)
    return volume, valid


# for SUNRGBDTotal test
def get_extrinsics(angles):
    yaw = angles.new_zeros(())
    pitch, roll = angles
    r = angles.new_zeros((3, 3))
    r[0, 0] = torch.cos(yaw) * torch.cos(pitch)
    r[0, 1] = torch.sin(yaw) * torch.sin(roll) - torch.cos(yaw) * torch.cos(
        roll) * torch.sin(pitch)
    r[0, 2] = torch.cos(roll) * torch.sin(yaw) + torch.cos(yaw) * torch.sin(
        pitch) * torch.sin(roll)
    r[1, 0] = torch.sin(pitch)
    r[1, 1] = torch.cos(pitch) * torch.cos(roll)
    r[1, 2] = -torch.cos(pitch) * torch.sin(roll)
    r[2, 0] = -torch.cos(pitch) * torch.sin(yaw)
    r[2, 1] = torch.cos(yaw) * torch.sin(roll) + torch.cos(roll) * torch.sin(
        yaw) * torch.sin(pitch)
    r[2, 2] = torch.cos(yaw) * torch.cos(roll) - torch.sin(yaw) * torch.sin(
        pitch) * torch.sin(roll)

    # follow Total3DUnderstanding
    t = angles.new_tensor([[0., 0., 1.], [0., -1., 0.], [-1., 0., 0.]])
    r = t @ r.T
    # follow DepthInstance3DBoxes
    r = r[:, [2, 0, 1]]
    r[2] *= -1
    extrinsic = angles.new_zeros((4, 4))
    extrinsic[:3, :3] = r
    extrinsic[3, 3] = 1.
    return extrinsic
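

# The helper below is an illustrative sketch only; it is not called anywhere
# in this module. It shows, with synthetic tensors, how ``get_points`` and
# ``backproject`` cooperate to build the multi-view feature volume and its
# validity mask. All shapes and camera parameters are made up for the
# example.
def _demo_backprojection():
    """Minimal usage sketch of ``get_points`` and ``backproject``."""
    n_images, n_channels, height, width = 2, 8, 60, 80
    n_voxels = torch.tensor([40, 40, 16])
    voxel_size = torch.tensor([.16, .16, .2])
    origin = torch.tensor([0., 0., 2.])  # volume center, in front of camera

    # Voxel-center coordinates in the scene frame, shape (3, 40, 40, 16).
    points = get_points(n_voxels, voxel_size, origin)

    # A toy pinhole projection (intrinsic @ extrinsic[:3]) per view,
    # stacked to shape (n_images, 3, 4).
    intrinsic = torch.tensor([[50., 0., width / 2.],
                              [0., 50., height / 2.],
                              [0., 0., 1.]])
    extrinsic = torch.eye(4)
    projection = torch.stack([intrinsic @ extrinsic[:3]] * n_images)

    features = torch.rand(n_images, n_channels, height, width)
    # Without a depth map, the validity mask only checks that a voxel center
    # projects inside the image and lies in front of the camera.
    volume, valid = backproject(features, points, projection, None,
                                voxel_size)
    assert volume.shape == (n_images, n_channels, 40, 40, 16)
    assert valid.shape == (n_images, 1, 40, 40, 16)
    return volume, valid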