# Copyright (c) OpenMMLab. All rights reserved. from typing import Optional, Tuple, Union import numpy as np from mmdet3d.datasets import PointSample from mmdet3d.registry import TRANSFORMS from mmdet3d.structures.points import BasePoints @TRANSFORMS.register_module() class TR3DPointSample(PointSample): """The only difference with PointSample is the support of float num_points parameter. In this case we sample random fraction of points from num_points to 100% points. These classes should be merged in the future. """ def _points_random_sampling( self, points: BasePoints, num_samples: Union[int, float], sample_range: Optional[float] = None, replace: bool = False, return_choices: bool = False ) -> Union[Tuple[BasePoints, np.ndarray], BasePoints]: """Points random sampling. Sample points to a certain number. Args: points (:obj:`BasePoints`): 3D Points. num_samples (int): Number of samples to be sampled. sample_range (float, optional): Indicating the range where the points will be sampled. Defaults to None. replace (bool): Sampling with or without replacement. Defaults to False. return_choices (bool): Whether return choice. Defaults to False. Returns: tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`: - points (:obj:`BasePoints`): 3D Points. - choices (np.ndarray, optional): The generated random samples. """ if isinstance(num_samples, float): assert num_samples < 1 num_samples = int( np.random.uniform(self.num_points, 1.) * points.shape[0]) if not replace: replace = (points.shape[0] < num_samples) point_range = range(len(points)) if sample_range is not None and not replace: # Only sampling the near points when len(points) >= num_samples dist = np.linalg.norm(points.coord.numpy(), axis=1) far_inds = np.where(dist >= sample_range)[0] near_inds = np.where(dist < sample_range)[0] # in case there are too many far points if len(far_inds) > num_samples: far_inds = np.random.choice( far_inds, num_samples, replace=False) point_range = near_inds num_samples -= len(far_inds) choices = np.random.choice(point_range, num_samples, replace=replace) if sample_range is not None and not replace: choices = np.concatenate((far_inds, choices)) # Shuffle points after sampling np.random.shuffle(choices) if return_choices: return points[choices], choices else: return points[choices]