# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from mmcv.cnn import ConvModule
from mmcv.ops.group_points import GroupAll, QueryAndGroup, grouping_operation
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F
from mmdet3d.utils import ConfigType


class BaseDGCNNGFModule(nn.Module):
"""Base module for point graph feature module used in DGCNN.
Args:
radii (List[float]): List of radius in each knn or ball query.
sample_nums (List[int]): Number of samples in each knn or ball query.
mlp_channels (List[List[int]]): Specify of the dgcnn before the global
pooling for each graph feature module.
knn_modes (List[str]): Type of KNN method, valid mode
['F-KNN', 'D-KNN']. Defaults to ['F-KNN'].
dilated_group (bool): Whether to use dilated ball query.
Defaults to False.
use_xyz (bool): Whether to use xyz as point features.
Defaults to True.
pool_mode (str): Type of pooling method. Defaults to 'max'.
        normalize_xyz (bool): Whether to normalize the local XYZ with the
            radius when using ball query. Defaults to False.
grouper_return_grouped_xyz (bool): Whether to return grouped xyz in
`QueryAndGroup`. Defaults to False.
grouper_return_grouped_idx (bool): Whether to return grouped idx in
`QueryAndGroup`. Defaults to False.
"""
def __init__(self,
radii: List[float],
sample_nums: List[int],
mlp_channels: List[List[int]],
knn_modes: List[str] = ['F-KNN'],
dilated_group: bool = False,
use_xyz: bool = True,
pool_mode: str = 'max',
normalize_xyz: bool = False,
grouper_return_grouped_xyz: bool = False,
grouper_return_grouped_idx: bool = False) -> None:
super(BaseDGCNNGFModule, self).__init__()
        assert len(sample_nums) == len(mlp_channels), \
            'Num_samples and mlp_channels should have the same length.'
        assert pool_mode in ['max', 'avg'], \
            "Pool_mode should be one of ['max', 'avg']."
        assert isinstance(knn_modes, (list, tuple)), \
            'The type of knn_modes should be list or tuple.'
if isinstance(mlp_channels, tuple):
mlp_channels = list(map(list, mlp_channels))
self.mlp_channels = mlp_channels
self.pool_mode = pool_mode
self.groupers = nn.ModuleList()
self.mlps = nn.ModuleList()
self.knn_modes = knn_modes
for i in range(len(sample_nums)):
sample_num = sample_nums[i]
if sample_num is not None:
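                # 'D-KNN' asks QueryAndGroup for the neighbour indices so
                # that forward() can gather and subtract features itself.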
if self.knn_modes[i] == 'D-KNN':
grouper = QueryAndGroup(
radii[i],
sample_num,
use_xyz=use_xyz,
normalize_xyz=normalize_xyz,
return_grouped_xyz=grouper_return_grouped_xyz,
return_grouped_idx=True)
else:
grouper = QueryAndGroup(
radii[i],
sample_num,
use_xyz=use_xyz,
normalize_xyz=normalize_xyz,
return_grouped_xyz=grouper_return_grouped_xyz,
return_grouped_idx=grouper_return_grouped_idx)
else:
grouper = GroupAll(use_xyz)
            self.groupers.append(grouper)

    def _pool_features(self, features: Tensor) -> Tensor:
"""Perform feature aggregation using pooling operation.
Args:
features (Tensor): (B, C, N, K) Features of locally grouped
points before pooling.
Returns:
Tensor: (B, C, N) Pooled features aggregating local information.
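
        Example:
            >>> # A minimal shape sketch with hypothetical sizes; runs on
            >>> # CPU since only torch pooling is involved.
            >>> module = DGCNNGFModule(mlp_channels=[6, 64])
            >>> feats = torch.rand(2, 8, 16, 4)  # (B, C, N, K)
            >>> module._pool_features(feats).shape
            torch.Size([2, 8, 16])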
"""
if self.pool_mode == 'max':
# (B, C, N, 1)
new_features = F.max_pool2d(
features, kernel_size=[1, features.size(3)])
elif self.pool_mode == 'avg':
# (B, C, N, 1)
new_features = F.avg_pool2d(
features, kernel_size=[1, features.size(3)])
else:
raise NotImplementedError
        return new_features.squeeze(-1).contiguous()

    def forward(self, points: Tensor) -> Tensor:
"""forward.
Args:
points (Tensor): (B, N, C) Input points.
Returns:
Tensor: (B, N, C1) New points generated from each graph
feature module.
"""
new_points_list = [points]
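        # Each iteration rebuilds a neighbourhood graph from the previous
        # output (the dynamic graph in DGCNN) and extracts new features.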
for i in range(len(self.groupers)):
new_points = new_points_list[i]
new_points_trans = new_points.transpose(
1, 2).contiguous() # (B, C, N)
if self.knn_modes[i] == 'D-KNN':
# (B, N, C) -> (B, N, K)
idx = self.groupers[i](new_points[..., -3:].contiguous(),
new_points[..., -3:].contiguous())[-1]
grouped_results = grouping_operation(
new_points_trans, idx) # (B, C, N) -> (B, C, N, K)
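                # edge features: neighbour features minus the centre point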
grouped_results -= new_points_trans.unsqueeze(-1)
else:
grouped_results = self.groupers[i](
new_points, new_points) # (B, N, C) -> (B, C, N, K)
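            # repeat the centre features K times and concatenate with the
            # grouped features, doubling the channel dimension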
new_points = new_points_trans.unsqueeze(-1).repeat(
1, 1, 1, grouped_results.shape[-1])
new_points = torch.cat([grouped_results, new_points], dim=1)
# (B, mlp[-1], N, K)
new_points = self.mlps[i](new_points)
# (B, mlp[-1], N)
new_points = self._pool_features(new_points)
new_points = new_points.transpose(1, 2).contiguous()
new_points_list.append(new_points)
        return new_points


class DGCNNGFModule(BaseDGCNNGFModule):
"""Point graph feature module used in DGCNN.
Args:
mlp_channels (List[int]): Specify of the dgcnn before the global
pooling for each graph feature module.
num_sample (int, optional): Number of samples in each knn or ball
query. Defaults to None.
knn_mode (str): Type of KNN method, valid mode ['F-KNN', 'D-KNN'].
Defaults to 'F-KNN'.
radius (float, optional): Radius to group with. Defaults to None.
dilated_group (bool): Whether to use dilated ball query.
Defaults to False.
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
layer. Defaults to dict(type='BN2d').
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
Defaults to dict(type='ReLU').
use_xyz (bool): Whether to use xyz as point features. Defaults to True.
pool_mode (str): Type of pooling method. Defaults to 'max'.
        normalize_xyz (bool): Whether to normalize the local XYZ with the
            radius when using ball query. Defaults to False.
bias (bool or str): If specified as `auto`, it will be decided by
`norm_cfg`. `bias` will be set as True if `norm_cfg` is None,
otherwise False. Defaults to 'auto'.
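
    Example:
        >>> # A minimal construction sketch; forward() additionally needs
        >>> # a CUDA build of the mmcv ops (knn/grouping kernels).
        >>> module = DGCNNGFModule(
        ...     mlp_channels=[6, 64], num_sample=16, knn_mode='D-KNN')
        >>> # With 3-channel inputs, mlp_channels[0] is 2 * 3 because edge
        >>> # and centre features are concatenated in forward().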
"""
def __init__(self,
mlp_channels: List[int],
num_sample: Optional[int] = None,
knn_mode: str = 'F-KNN',
radius: Optional[float] = None,
dilated_group: bool = False,
norm_cfg: ConfigType = dict(type='BN2d'),
act_cfg: ConfigType = dict(type='ReLU'),
use_xyz: bool = True,
pool_mode: str = 'max',
normalize_xyz: bool = False,
bias: Union[bool, str] = 'auto') -> None:
super(DGCNNGFModule, self).__init__(
mlp_channels=[mlp_channels],
sample_nums=[num_sample],
knn_modes=[knn_mode],
radii=[radius],
use_xyz=use_xyz,
pool_mode=pool_mode,
normalize_xyz=normalize_xyz,
dilated_group=dilated_group)
for i in range(len(self.mlp_channels)):
mlp_channel = self.mlp_channels[i]
mlp = nn.Sequential()
            for j in range(len(mlp_channel) - 1):
                mlp.add_module(
                    f'layer{j}',
                    ConvModule(
                        mlp_channel[j],
                        mlp_channel[j + 1],
kernel_size=(1, 1),
stride=(1, 1),
conv_cfg=dict(type='Conv2d'),
norm_cfg=norm_cfg,
act_cfg=act_cfg,
bias=bias))
self.mlps.append(mlp)
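

# A minimal smoke-test sketch (not part of the upstream module); it assumes
# a CUDA build of the mmcv ops, since the knn/grouping kernels are GPU-only.
if __name__ == '__main__':
    if torch.cuda.is_available():
        points = torch.randn(2, 1024, 3).cuda()  # (B, N, C) input points
        # With 'D-KNN', neighbours are queried on the last 3 channels and
        # mlp_channels[0] must be 2 * C (edge + centre features).
        gf_module = DGCNNGFModule(
            mlp_channels=[6, 64], num_sample=16, knn_mode='D-KNN').cuda()
        new_points = gf_module(points)
        assert new_points.shape == (2, 1024, 64)  # (B, N, mlp[-1])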