# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

from torch import Tensor

from mmdet3d.models import EncoderDecoder3D
from mmdet3d.registry import MODELS
from mmdet3d.structures import PointData
from mmdet3d.structures.det3d_data_sample import OptSampleList, SampleList


@MODELS.register_module()
class RangeImageSegmentor(EncoderDecoder3D):
    """Encoder-decoder segmentor driven by the ``'imgs'`` input entry.

    Every entry point below reads ``batch_inputs_dict['imgs']``, passes it
    through ``self.extract_feat`` and hands the resulting features to the
    decode head (and, during training, to the auxiliary head if present).
    """

    def loss(self, batch_inputs_dict: dict,
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Compute the training losses for one batch.

        Args:
            batch_inputs_dict (dict): Input sample dict. Only the 'imgs'
                entry is read by this method.

                - imgs (Tensor): Image tensor, documented upstream as
                  having shape (B, C, H, W).
            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d
                data samples, forwarded unchanged to the head(s).

        Returns:
            Dict[str, Tensor]: Decode-head losses, merged with the
            auxiliary-head losses when ``self.with_auxiliary_head`` is
            truthy.
        """
        feats = self.extract_feat(batch_inputs_dict['imgs'])

        losses = {}
        losses.update(
            self._decode_head_forward_train(feats, batch_data_samples))

        # The auxiliary head is optional; its entries are merged into the
        # same dict as the decode-head entries.
        if self.with_auxiliary_head:
            losses.update(
                self._auxiliary_head_forward_train(feats,
                                                   batch_data_samples))

        return losses

    def predict(self,
                batch_inputs_dict: dict,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Run inference and attach the predictions to the data samples.

        Args:
            batch_inputs_dict (dict): Input sample dict. Only the 'imgs'
                entry is read by this method.

                - imgs (Tensor): Image tensor, documented upstream as
                  having shape (B, C, H, W).
            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d
                data samples. The ``metainfo`` of each sample is passed to
                the decode head.
            rescale (bool): Accepted for interface compatibility; it is
                not read by this implementation. Defaults to True.

        Returns:
            List[:obj:`Det3DDataSample`]: The data samples as returned by
            :meth:`postprocess_result`, i.e. with ``pred_pts_seg`` set
            from the decode-head predictions.
        """
        # Gather the meta information of every sample before running the
        # network; the decode head receives it next to the features.
        metas = [sample.metainfo for sample in batch_data_samples]

        feats = self.extract_feat(batch_inputs_dict['imgs'])
        predictions = self.decode_head.predict(feats, metas, self.test_cfg)

        return self.postprocess_result(predictions, batch_data_samples)

    def _forward(self,
                 batch_inputs_dict: dict,
                 batch_data_samples: OptSampleList = None) -> Tensor:
        """Run the backbone and decode head without any post-processing.

        Args:
            batch_inputs_dict (dict): Input sample dict. Only the 'imgs'
                entry is read by this method.

                - imgs (Tensor): Image tensor, documented upstream as
                  having shape (B, C, H, W).
            batch_data_samples (List[:obj:`Det3DDataSample`], optional):
                Not read by this implementation. Defaults to None.

        Returns:
            Tensor: Whatever ``self.decode_head.forward`` returns for the
            extracted features.
        """
        return self.decode_head.forward(
            self.extract_feat(batch_inputs_dict['imgs']))

    def postprocess_result(self, seg_labels_list: List[Tensor],
                           batch_data_samples: SampleList) -> SampleList:
        """Store per-sample predictions on the matching data samples.

        The i-th element of ``seg_labels_list`` is wrapped in a new
        ``PointData`` under the key ``pts_semantic_mask`` and set as
        ``pred_pts_seg`` on ``batch_data_samples[i]``. No other field is
        set here.

        Args:
            seg_labels_list (List[Tensor]): One segmentation result per
                input sample.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d
                data samples that receive the predictions. Samples are
                looked up by index, so the list must hold at least as
                many entries as ``seg_labels_list``.

        Returns:
            List[:obj:`Det3DDataSample`]: The same ``batch_data_samples``
            list that was passed in.
        """
        for idx, seg_pred in enumerate(seg_labels_list):
            sample = batch_data_samples[idx]
            sample.set_data({
                'pred_pts_seg':
                PointData(**{'pts_semantic_mask': seg_pred})
            })
        return batch_data_samples