File size: 3,901 Bytes
34d1f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/SamsungLabs/tr3d/blob/master/mmdet3d/models/necks/tr3d_neck.py # noqa
from typing import List, Tuple

try:
    import MinkowskiEngine as ME
    from MinkowskiEngine import SparseTensor
except ImportError:
    # Please follow getting_started.md to install MinkowskiEngine.
    ME = SparseTensor = None
    pass

from mmengine.model import BaseModule
from torch import nn

from mmdet3d.registry import MODELS


@MODELS.register_module()
class TR3DNeck(BaseModule):
    r"""Neck of `TR3D <https://arxiv.org/abs/2302.02858>`_.

    Args:
        in_channels (tuple[int]): Number of channels in input tensors.
        out_channels (int): Number of channels in output tensors.
    """

    def __init__(self, in_channels: Tuple[int], out_channels: int):
        super(TR3DNeck, self).__init__()
        self._init_layers(in_channels[1:], out_channels)

    def _init_layers(self, in_channels: Tuple[int], out_channels: int):
        """Initialize layers.

        Args:
            in_channels (tuple[int]): Number of channels in input tensors.
            out_channels (int): Number of channels in output tensors.
        """
        for i in range(len(in_channels)):
            if i > 0:
                self.add_module(
                    f'up_block_{i}',
                    self._make_block(in_channels[i], in_channels[i - 1], True,
                                     2))
            if i < len(in_channels) - 1:
                self.add_module(
                    f'lateral_block_{i}',
                    self._make_block(in_channels[i], in_channels[i]))
                self.add_module(f'out_block_{i}',
                                self._make_block(in_channels[i], out_channels))

    def init_weights(self):
        """Initialize weights."""
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(
                    m.kernel, mode='fan_out', nonlinearity='relu')

            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def forward(self, x: List[SparseTensor]) -> List[SparseTensor]:
        """Forward pass.

        Args:
            x (list[SparseTensor]): Features from the backbone.

        Returns:
            List[Tensor]: Output features from the neck.
        """
        x = x[1:]
        outs = []
        inputs = x
        x = inputs[-1]
        for i in range(len(inputs) - 1, -1, -1):
            if i < len(inputs) - 1:
                x = self.__getattr__(f'up_block_{i + 1}')(x)
                x = inputs[i] + x
                x = self.__getattr__(f'lateral_block_{i}')(x)
                out = self.__getattr__(f'out_block_{i}')(x)
                outs.append(out)
        return outs[::-1]

    @staticmethod
    def _make_block(in_channels: int,
                    out_channels: int,
                    generative: bool = False,
                    stride: int = 1) -> nn.Module:
        """Construct Conv-Norm-Act block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            generative (bool): Use generative convolution if True.
                Defaults to False.
            stride (int): Stride of the convolution. Defaults to 1.

        Returns:
            torch.nn.Module: With corresponding layers.
        """
        conv = ME.MinkowskiGenerativeConvolutionTranspose if generative \
            else ME.MinkowskiConvolution
        return nn.Sequential(
            conv(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                dimension=3), ME.MinkowskiBatchNorm(out_channels),
            ME.MinkowskiReLU(inplace=True))