|
|
|
from typing import List, Sequence, Union |
|
|
|
import mmengine |
|
import numpy as np |
|
import torch |
|
from mmcv.transforms import BaseTransform
|
from mmengine.structures import InstanceData |
|
from numpy import dtype |
|
|
|
from mmdet3d.registry import TRANSFORMS |
|
from mmdet3d.structures import BaseInstance3DBoxes, PointData |
|
from mmdet3d.structures.points import BasePoints |
|
|
|
from .nerf_det3d_data_sample import NeRFDet3DDataSample |
|
|
|
|
|
def to_tensor( |
|
data: Union[torch.Tensor, np.ndarray, Sequence, int, |
|
float]) -> torch.Tensor: |
|
"""Convert objects of various python types to :obj:`torch.Tensor`. |
|
|
|
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, |
|
:class:`Sequence`, :class:`int` and :class:`float`. |
|
|
|
Args: |
|
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to |
|
be converted. |
|
|
|
Returns: |
|
torch.Tensor: the converted data. |
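
    Examples:
        >>> # Illustrative examples only (not part of the original docstring).
        >>> import numpy as np
        >>> to_tensor(np.zeros((2, 3), dtype=np.float64)).dtype
        torch.float32
        >>> to_tensor([1, 2, 3]).dtype
        torch.int64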
|
""" |
|
|
|
if isinstance(data, torch.Tensor): |
|
return data |
|
elif isinstance(data, np.ndarray): |
|
if data.dtype is dtype('float64'): |
|
data = data.astype(np.float32) |
|
return torch.from_numpy(data) |
|
elif isinstance(data, Sequence) and not mmengine.is_str(data): |
|
return torch.tensor(data) |
|
elif isinstance(data, int): |
|
return torch.LongTensor([data]) |
|
elif isinstance(data, float): |
|
return torch.FloatTensor([data]) |
|
else: |
|
raise TypeError(f'type {type(data)} cannot be converted to tensor.') |
|
|
|
|
|
@TRANSFORMS.register_module() |
|
class PackNeRFDetInputs(BaseTransform): |
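    """Pack NeRF-Det inputs and annotations into a ``NeRFDet3DDataSample``.

    The transform converts images, depth maps, ray information and
    ground-truth annotations in ``results`` to tensors, collects the entries
    listed in ``keys`` into ``inputs`` or the corresponding annotation fields,
    and records the entries listed in ``meta_keys`` as metainfo of the
    returned data sample.

    Args:
        keys (tuple): Keys of ``results`` to be packed as model inputs or
            annotations.
        meta_keys (tuple): Keys of ``results`` to be recorded as metainfo of
            the data sample. Defaults to a common set of image and point
            cloud meta keys.
    """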
|
INPUTS_KEYS = ['points', 'img'] |
|
NERF_INPUT_KEYS = [ |
|
'img', 'denorm_images', 'depth', 'lightpos', 'nerf_sizes', 'raydirs' |
|
] |
|
|
|
INSTANCEDATA_3D_KEYS = [ |
|
'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d' |
|
] |
|
INSTANCEDATA_2D_KEYS = [ |
|
'gt_bboxes', |
|
'gt_bboxes_labels', |
|
] |
|
NERF_3D_KEYS = ['gt_images', 'gt_depths'] |
|
|
|
SEG_KEYS = [ |
|
'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask', |
|
'gt_semantic_seg' |
|
] |
|
|
|
def __init__( |
|
self, |
|
keys: tuple, |
|
meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img', |
|
'depth2img', 'cam2img', 'pad_shape', |
|
'scale_factor', 'flip', 'pcd_horizontal_flip', |
|
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d', |
|
'img_norm_cfg', 'num_pts_feats', 'pcd_trans', |
|
'sample_idx', 'pcd_scale_factor', 'pcd_rotation', |
|
'pcd_rotation_angle', 'lidar_path', |
|
'transformation_3d_flow', 'trans_mat', |
|
'affine_aug', 'sweep_img_metas', 'ori_cam2img', |
|
'cam2global', 'crop_offset', 'img_crop_offset', |
|
'resize_img_shape', 'lidar2cam', 'ori_lidar2img', |
|
'num_ref_frames', 'num_views', 'ego2global', |
|
'axis_align_matrix') |
|
) -> None: |
|
self.keys = keys |
|
self.meta_keys = meta_keys |
|
|
|
def _remove_prefix(self, key: str) -> str: |
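        """Strip the ``gt_`` prefix from an annotation key."""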
|
if key.startswith('gt_'): |
|
key = key[3:] |
|
return key |
|
|
|
def transform(self, results: Union[dict, |
|
List[dict]]) -> Union[dict, List[dict]]: |
|
"""Method to pack the input data. when the value in this dict is a |
|
list, it usually is in Augmentations Testing. |
|
|
|
Args: |
|
results (dict | list[dict]): Result dict from the data pipeline. |
|
|
|
Returns: |
|
dict | List[dict]: |
|
|
|
            - 'inputs' (dict): The forward data of models. It usually contains
              the following keys:
|
|
|
- points |
|
- img |
|
|
|
- 'data_samples' (:obj:`NeRFDet3DDataSample`): The annotation info |
|
of the sample. |
|
""" |
|
|
|
if isinstance(results, list): |
|
if len(results) == 1: |
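                # simple case: only a single result dict needs to be packed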
|
|
|
return self.pack_single_results(results[0]) |
|
pack_results = [] |
|
for single_result in results: |
|
pack_results.append(self.pack_single_results(single_result)) |
|
return pack_results |
|
|
|
elif isinstance(results, dict): |
|
return self.pack_single_results(results) |
|
else: |
|
raise NotImplementedError |
|
|
|
def pack_single_results(self, results: dict) -> dict: |
|
"""Method to pack the single input data. when the value in this dict is |
|
a list, it usually is in Augmentations Testing. |
|
|
|
Args: |
|
results (dict): Result dict from the data pipeline. |
|
|
|
Returns: |
|
            dict: A dict containing
|
|
|
            - 'inputs' (dict): The forward data of models. It usually contains
              the following keys:
|
|
|
- points |
|
- img |
|
|
|
- 'data_samples' (:obj:`NeRFDet3DDataSample`): The annotation info |
|
of the sample. |
|
""" |
|
|
|
if 'points' in results: |
|
if isinstance(results['points'], BasePoints): |
|
results['points'] = results['points'].tensor |
|
|
|
if 'img' in results: |
|
if isinstance(results['img'], list): |
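                # stack and pack multiple images from a single frame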
|
|
|
imgs = np.stack(results['img'], axis=0) |
|
if imgs.flags.c_contiguous: |
|
imgs = to_tensor(imgs).permute(0, 3, 1, 2).contiguous() |
|
else: |
|
imgs = to_tensor( |
|
np.ascontiguousarray(imgs.transpose(0, 3, 1, 2))) |
|
results['img'] = imgs |
|
else: |
|
img = results['img'] |
|
if len(img.shape) < 3: |
|
img = np.expand_dims(img, -1) |
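
                # On a C-contiguous array, converting to a tensor first and
                # then calling `torch.permute()` is faster than
                # `np.transpose()`, so the conversion path is chosen based on
                # the memory layout.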
|
|
|
|
|
|
|
|
|
if img.flags.c_contiguous: |
|
img = to_tensor(img).permute(2, 0, 1).contiguous() |
|
else: |
|
img = to_tensor( |
|
np.ascontiguousarray(img.transpose(2, 0, 1))) |
|
results['img'] = img |
|
|
|
if 'depth' in results: |
|
if isinstance(results['depth'], list): |
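                # stack per-view depth maps into a single tensor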
|
|
|
depth_imgs = np.stack(results['depth'], axis=0) |
|
if depth_imgs.flags.c_contiguous: |
|
depth_imgs = to_tensor(depth_imgs).contiguous() |
|
else: |
|
depth_imgs = to_tensor(np.ascontiguousarray(depth_imgs)) |
|
results['depth'] = depth_imgs |
|
else: |
|
depth_img = results['depth'] |
|
if len(depth_img.shape) < 3: |
|
depth_img = np.expand_dims(depth_img, -1) |
|
if depth_img.flags.c_contiguous: |
|
depth_img = to_tensor(depth_img).contiguous() |
|
else: |
|
depth_img = to_tensor(np.ascontiguousarray(depth_img)) |
|
results['depth'] = depth_img |
|
|
|
if 'ray_info' in results: |
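            # pack NeRF-specific ray information: ray directions, light
            # (camera) positions and the rendering targets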
|
if isinstance(results['raydirs'], list): |
|
raydirs = np.stack(results['raydirs'], axis=0) |
|
if raydirs.flags.c_contiguous: |
|
raydirs = to_tensor(raydirs).contiguous() |
|
else: |
|
raydirs = to_tensor(np.ascontiguousarray(raydirs)) |
|
results['raydirs'] = raydirs |
|
|
|
if isinstance(results['lightpos'], list): |
|
lightposes = np.stack(results['lightpos'], axis=0) |
|
if lightposes.flags.c_contiguous: |
|
lightposes = to_tensor(lightposes).contiguous() |
|
else: |
|
lightposes = to_tensor(np.ascontiguousarray(lightposes)) |
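                # repeat each view's light position so that it is paired with
                # every ray direction of that view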
|
lightposes = lightposes.unsqueeze(1).repeat( |
|
1, raydirs.shape[1], 1) |
|
results['lightpos'] = lightposes |
|
|
|
if isinstance(results['gt_images'], list): |
|
gt_images = np.stack(results['gt_images'], axis=0) |
|
if gt_images.flags.c_contiguous: |
|
gt_images = to_tensor(gt_images).contiguous() |
|
else: |
|
gt_images = to_tensor(np.ascontiguousarray(gt_images)) |
|
results['gt_images'] = gt_images |
|
|
|
if isinstance(results['gt_depths'], |
|
list) and len(results['gt_depths']) != 0: |
|
gt_depths = np.stack(results['gt_depths'], axis=0) |
|
if gt_depths.flags.c_contiguous: |
|
gt_depths = to_tensor(gt_depths).contiguous() |
|
else: |
|
gt_depths = to_tensor(np.ascontiguousarray(gt_depths)) |
|
results['gt_depths'] = gt_depths |
|
|
|
if isinstance(results['denorm_images'], list): |
|
denorm_imgs = np.stack(results['denorm_images'], axis=0) |
|
if denorm_imgs.flags.c_contiguous: |
|
denorm_imgs = to_tensor(denorm_imgs).permute( |
|
0, 3, 1, 2).contiguous() |
|
else: |
|
denorm_imgs = to_tensor( |
|
np.ascontiguousarray( |
|
denorm_imgs.transpose(0, 3, 1, 2))) |
|
results['denorm_images'] = denorm_imgs |
|
|
|
for key in [ |
|
'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels', |
|
'gt_bboxes_labels', 'attr_labels', 'pts_instance_mask', |
|
'pts_semantic_mask', 'centers_2d', 'depths', 'gt_labels_3d' |
|
]: |
|
if key not in results: |
|
continue |
|
if isinstance(results[key], list): |
|
results[key] = [to_tensor(res) for res in results[key]] |
|
else: |
|
results[key] = to_tensor(results[key]) |
|
if 'gt_bboxes_3d' in results: |
|
if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes): |
|
results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d']) |
|
|
|
if 'gt_semantic_seg' in results: |
|
results['gt_semantic_seg'] = to_tensor( |
|
results['gt_semantic_seg'][None]) |
|
if 'gt_seg_map' in results: |
|
results['gt_seg_map'] = results['gt_seg_map'][None, ...] |
|
|
|
if 'gt_images' in results: |
|
results['gt_images'] = to_tensor(results['gt_images']) |
|
if 'gt_depths' in results: |
|
results['gt_depths'] = to_tensor(results['gt_depths']) |
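
        # collect annotations and meta information into a NeRFDet3DDataSample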
|
|
|
data_sample = NeRFDet3DDataSample() |
|
gt_instances_3d = InstanceData() |
|
gt_instances = InstanceData() |
|
gt_pts_seg = PointData() |
|
gt_nerf_images = InstanceData() |
|
gt_nerf_depths = InstanceData() |
|
|
|
data_metas = {} |
|
for key in self.meta_keys: |
|
if key in results: |
|
data_metas[key] = results[key] |
|
elif 'images' in results: |
|
if len(results['images'].keys()) == 1: |
|
cam_type = list(results['images'].keys())[0] |
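                    # single-view image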
|
|
|
if key in results['images'][cam_type]: |
|
data_metas[key] = results['images'][cam_type][key] |
|
else: |
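                    # multi-view images: gather the meta value from each camera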
|
|
|
img_metas = [] |
|
cam_types = list(results['images'].keys()) |
|
for cam_type in cam_types: |
|
if key in results['images'][cam_type]: |
|
img_metas.append(results['images'][cam_type][key]) |
|
if len(img_metas) > 0: |
|
data_metas[key] = img_metas |
|
elif 'lidar_points' in results: |
|
if key in results['lidar_points']: |
|
data_metas[key] = results['lidar_points'][key] |
|
data_sample.set_metainfo(data_metas) |
|
|
|
inputs = {} |
|
for key in self.keys: |
|
if key in results: |
|
|
|
if key in self.NERF_INPUT_KEYS: |
|
inputs[key] = results[key] |
|
elif key in self.NERF_3D_KEYS: |
|
if key == 'gt_images': |
|
gt_nerf_images[self._remove_prefix(key)] = results[key] |
|
else: |
|
gt_nerf_depths[self._remove_prefix(key)] = results[key] |
|
elif key in self.INSTANCEDATA_3D_KEYS: |
|
gt_instances_3d[self._remove_prefix(key)] = results[key] |
|
elif key in self.INSTANCEDATA_2D_KEYS: |
|
if key == 'gt_bboxes_labels': |
|
gt_instances['labels'] = results[key] |
|
else: |
|
gt_instances[self._remove_prefix(key)] = results[key] |
|
elif key in self.SEG_KEYS: |
|
gt_pts_seg[self._remove_prefix(key)] = results[key] |
|
else: |
|
                    raise NotImplementedError(
                        f'Please modify `PackNeRFDetInputs` to put {key} '
                        'into the corresponding field.')
|
|
|
data_sample.gt_instances_3d = gt_instances_3d |
|
data_sample.gt_instances = gt_instances |
|
data_sample.gt_pts_seg = gt_pts_seg |
|
data_sample.gt_nerf_images = gt_nerf_images |
|
data_sample.gt_nerf_depths = gt_nerf_depths |
|
if 'eval_ann_info' in results: |
|
data_sample.eval_ann_info = results['eval_ann_info'] |
|
else: |
|
data_sample.eval_ann_info = None |
|
|
|
packed_results = dict() |
|
packed_results['data_samples'] = data_sample |
|
packed_results['inputs'] = inputs |
|
|
|
return packed_results |
|
|
|
def __repr__(self) -> str: |
|
"""str: Return a string that describes the module.""" |
|
repr_str = self.__class__.__name__ |
|
        repr_str += f'(keys={self.keys}, meta_keys={self.meta_keys})'
|
return repr_str |
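

# A minimal pipeline-config sketch for reference (illustrative only; the key
# list and the surrounding transforms depend on the NeRF-Det config actually
# used):
#
#   train_pipeline = [
#       ...,
#       dict(
#           type='PackNeRFDetInputs',
#           keys=['img', 'gt_bboxes_3d', 'gt_labels_3d', 'depth', 'lightpos',
#                 'nerf_sizes', 'raydirs', 'gt_images', 'gt_depths',
#                 'denorm_images'])
#   ]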
|
|