# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, List
import torch
from mmengine.model import BaseModule, normal_init
from torch import Tensor
from torch import nn as nn
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils.typing_utils import ConfigType, OptMultiConfig
class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for BaseDecodeHead.
1. The ``init_weights`` method is used to initialize decode_head's
model parameters. After segmentor initialization, ``init_weights``
is triggered when ``segmentor.init_weights()`` is called externally.
2. The ``loss`` method is used to calculate the loss of decode_head,
which includes two steps: (1) the decode_head model performs forward
propagation to obtain the feature maps (2) The ``loss_by_feat`` method
is called based on the feature maps to calculate the loss.
.. code:: text
loss(): forward() -> loss_by_feat()
3. The ``predict`` method is used to predict segmentation results,
which includes two steps: (1) the decode_head model performs forward
propagation to obtain the feature maps (2) The ``predict_by_feat`` method
is called based on the feature maps to predict segmentation results
including post-processing.
.. code:: text
predict(): forward() -> predict_by_feat()
Args:
        channels (int): Number of input channels of ``conv_seg``, i.e. the
            feature channels produced by the decoder modules.
num_classes (int): Number of classes.
dropout_ratio (float): Ratio of dropout layer. Defaults to 0.5.
conv_cfg (dict or :obj:`ConfigDict`): Config of conv layers.
Defaults to dict(type='Conv1d').
norm_cfg (dict or :obj:`ConfigDict`): Config of norm layers.
Defaults to dict(type='BN1d').
act_cfg (dict or :obj:`ConfigDict`): Config of activation layers.
Defaults to dict(type='ReLU').
loss_decode (dict or :obj:`ConfigDict`): Config of decode loss.
Defaults to dict(type='mmdet.CrossEntropyLoss', use_sigmoid=False,
class_weight=None, loss_weight=1.0).
conv_seg_kernel_size (int): The kernel size used in conv_seg.
Defaults to 1.
ignore_index (int): The label index to be ignored. When using masked
BCE loss, ignore_index should be set to None. Defaults to 255.
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
optional): Initialization config dict. Defaults to None.
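    Example:
        A minimal sketch of a hypothetical subclass (``ToyDecodeHead``, its
        ``pre_seg`` layer and the ``'features'`` key are illustrative
        assumptions, not part of this module):

        .. code:: python

            @MODELS.register_module()
            class ToyDecodeHead(Base3DDecodeHead):

                def __init__(self, in_channels, **kwargs):
                    super().__init__(**kwargs)
                    self.pre_seg = nn.Conv1d(in_channels, self.channels, 1)

                def forward(self, feats_dict):
                    # Point features assumed to have shape [B, in_channels, N].
                    feat = self.pre_seg(feats_dict['features'])
                    return self.cls_seg(feat)

        The base arguments (``channels``, ``num_classes``, etc.) are passed
        through ``**kwargs``.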
"""
def __init__(self,
channels: int,
num_classes: int,
dropout_ratio: float = 0.5,
conv_cfg: ConfigType = dict(type='Conv1d'),
norm_cfg: ConfigType = dict(type='BN1d'),
act_cfg: ConfigType = dict(type='ReLU'),
loss_decode: ConfigType = dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=False,
class_weight=None,
loss_weight=1.0),
conv_seg_kernel_size: int = 1,
ignore_index: int = 255,
init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
self.channels = channels
self.num_classes = num_classes
self.dropout_ratio = dropout_ratio
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.loss_decode = MODELS.build(loss_decode)
self.ignore_index = ignore_index
self.conv_seg = self.build_conv_seg(
channels=channels,
num_classes=num_classes,
kernel_size=conv_seg_kernel_size)
if dropout_ratio > 0:
self.dropout = nn.Dropout(dropout_ratio)
else:
self.dropout = None
def init_weights(self) -> None:
"""Initialize weights of classification layer."""
super().init_weights()
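        # Draw conv_seg weights from N(0, 0.01); normal_init also sets the
        # bias to zero by default.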
normal_init(self.conv_seg, mean=0, std=0.01)
@abstractmethod
def forward(self, feats_dict: dict) -> Tensor:
"""Placeholder of forward function."""
pass
def build_conv_seg(self, channels: int, num_classes: int,
kernel_size: int) -> nn.Module:
"""Build Convolutional Segmentation Layers."""
return nn.Conv1d(channels, num_classes, kernel_size=kernel_size)
def cls_seg(self, feat: Tensor) -> Tensor:
"""Classify each points."""
if self.dropout is not None:
feat = self.dropout(feat)
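        # Project per-point features [B, channels, N] to class logits
        # [B, num_classes, N].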
output = self.conv_seg(feat)
return output
def loss(self, inputs: dict, batch_data_samples: SampleList,
train_cfg: ConfigType) -> Dict[str, Tensor]:
"""Forward function for training.
Args:
inputs (dict): Feature dict from backbone.
batch_data_samples (List[:obj:`Det3DDataSample`]): The seg data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`.
train_cfg (dict or :obj:`ConfigDict`): The training config.
Returns:
Dict[str, Tensor]: A dictionary of loss components.
"""
seg_logits = self.forward(inputs)
losses = self.loss_by_feat(seg_logits, batch_data_samples)
return losses
def predict(self, inputs: dict, batch_input_metas: List[dict],
test_cfg: ConfigType) -> Tensor:
"""Forward function for testing.
Args:
inputs (dict): Feature dict from backbone.
batch_input_metas (List[dict]): Meta information of a batch of
samples.
test_cfg (dict or :obj:`ConfigDict`): The testing config.
Returns:
            Tensor: Output point-wise segmentation logits.
"""
seg_logits = self.forward(inputs)
return seg_logits
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
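        # Gather per-sample point-wise semantic labels; stacking assumes every
        # sample has the same number of points and yields a [B, N] tensor.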
gt_semantic_segs = [
data_sample.gt_pts_seg.pts_semantic_mask
for data_sample in batch_data_samples
]
return torch.stack(gt_semantic_segs, dim=0)
def loss_by_feat(self, seg_logit: Tensor,
batch_data_samples: SampleList) -> Dict[str, Tensor]:
"""Compute semantic segmentation loss.
Args:
seg_logit (Tensor): Predicted per-point segmentation logits of
shape [B, num_classes, N].
batch_data_samples (List[:obj:`Det3DDataSample`]): The seg data
samples. It usually includes information such as `metainfo` and
`gt_pts_seg`.
Returns:
Dict[str, Tensor]: A dictionary of loss components.
"""
seg_label = self._stack_batch_gt(batch_data_samples)
loss = dict()
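        # The default ``mmdet.CrossEntropyLoss`` consumes logits of shape
        # [B, num_classes, N] and integer labels of shape [B, N].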
loss['loss_sem_seg'] = self.loss_decode(
seg_logit, seg_label, ignore_index=self.ignore_index)
return loss