|
|
|
from typing import Sequence |
|
|
|
from mmcv.cnn.bricks import ConvModule |
|
from torch import Tensor |
|
|
|
from mmdet3d.models.layers import DGCNNFPModule |
|
from mmdet3d.registry import MODELS |
|
from .decode_head import Base3DDecodeHead |
|
|
|
|
|
@MODELS.register_module()
class DGCNNHead(Base3DDecodeHead):
    r"""Decode head for DGCNN.

    Implements the decoder from
    `DGCNN <https://arxiv.org/abs/1801.07829>`_, following the
    `reimplementation code <https://github.com/AnTao97/dgcnn.pytorch>`_.

    The head feeds the backbone's ``fa_points`` entry through a feature
    propagation (FP) module, then a kernel-size-1 ``ConvModule``, and
    finally through ``cls_seg``.

    Args:
        fp_channels (Sequence[int]): Tuple of mlp channels in feature
            propagation (FP) modules. Defaults to (1216, 512).
        **kwargs: Passed through unchanged to the parent constructor.
    """

    def __init__(self,
                 fp_channels: Sequence[int] = (1216, 512),
                 **kwargs) -> None:
        super().__init__(**kwargs)

        # Built first so submodule registration order stays
        # FP_module -> pre_seg_conv.
        self.FP_module = DGCNNFPModule(
            mlp_channels=fp_channels, act_cfg=self.act_cfg)

        # The conv block consumes the last FP mlp width and emits
        # ``self.channels`` features; conv/norm/act configs are the ones
        # already stored on the head.
        fp_out_channels = fp_channels[-1]
        self.pre_seg_conv = ConvModule(
            fp_out_channels,
            self.channels,
            kernel_size=1,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _extract_input(self, feat_dict: dict) -> Tensor:
        """Pick the decoder input out of the backbone feature dict.

        Args:
            feat_dict (dict): Feature dict from backbone. Must contain the
                key ``'fa_points'``.

        Returns:
            torch.Tensor: The value stored under ``'fa_points'``.
        """
        return feat_dict['fa_points']

    def forward(self, feat_dict: dict) -> Tensor:
        """Run the decode head on backbone features.

        Args:
            feat_dict (dict): Feature dict from backbone.

        Returns:
            Tensor: Segmentation map of shape [B, num_classes, N].
        """
        decoder_input = self._extract_input(feat_dict)
        propagated = self.FP_module(decoder_input)

        # Swap dims 1 and 2 before the conv block -- presumably to put
        # channels on dim 1 as the conv expects; confirm against
        # DGCNNFPModule's output layout.
        conv_input = propagated.transpose(1, 2).contiguous()
        seg_feats = self.pre_seg_conv(conv_input)

        return self.cls_seg(seg_feats)
|
|