# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


@MODELS.register_module()
class JPU(BaseModule):
    """FastFCN: Rethinking Dilated Convolution in the Backbone
    for Semantic Segmentation.

    This Joint Pyramid Upsampling (JPU) neck is the implementation of
    `FastFCN <https://arxiv.org/abs/1903.11816>`_.

    Args:
        in_channels (Tuple[int], optional): The number of input channels
            of each convolution operation before upsampling.
            Default: (512, 1024, 2048).
        mid_channels (int): The number of output channels of each conv
            branch; the final JPU output has ``len(dilations) *
            mid_channels`` channels. Default: 512.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        dilations (tuple[int]): Dilation rate of each Depthwise
            Separable ConvModule. Default: (1, 2, 4, 8).
        align_corners (bool, optional): The ``align_corners`` argument of
            the resize operation. Default: False.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
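
    Example:
        >>> # A minimal shape check; the 8x/16x/32x strides of the inputs
        >>> # below are illustrative (e.g. C3-C5 of a ResNet backbone).
        >>> import torch
        >>> inputs = [torch.randn(2, c, 64 // 2**i, 64 // 2**i)
        ...           for i, c in enumerate((512, 1024, 2048))]
        >>> self = JPU()
        >>> outs = self(inputs)
        >>> assert outs[0].shape == (2, 512, 64, 64)
        >>> assert outs[1].shape == (2, 1024, 32, 32)
        >>> assert outs[2].shape == (2, 4 * 512, 64, 64)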
    """

    def __init__(self,
                 in_channels=(512, 1024, 2048),
                 mid_channels=512,
                 start_level=0,
                 end_level=-1,
                 dilations=(1, 2, 4, 8),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, tuple)
        assert isinstance(dilations, tuple)
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.start_level = start_level
        self.num_ins = len(in_channels)
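        # end_level == -1 selects all remaining backbone levels; otherwise
        # levels [start_level, end_level) are used.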
        if end_level == -1:
            self.backbone_end_level = self.num_ins
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)

        self.dilations = dilations
        self.align_corners = align_corners

        self.conv_layers = nn.ModuleList()
        self.dilation_layers = nn.ModuleList()
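        # One 3x3 conv per selected backbone level, projecting each feature
        # map to mid_channels before fusion.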
        for i in range(self.start_level, self.backbone_end_level):
            conv_layer = nn.Sequential(
                ConvModule(
                    self.in_channels[i],
                    self.mid_channels,
                    kernel_size=3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            self.conv_layers.append(conv_layer)
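        # Parallel depthwise separable conv branches over the concatenated
        # multi-level features, one branch per dilation rate.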
        for i in range(len(dilations)):
            dilation_layer = nn.Sequential(
                DepthwiseSeparableConvModule(
                    in_channels=(self.backbone_end_level - self.start_level) *
                    self.mid_channels,
                    out_channels=self.mid_channels,
                    kernel_size=3,
                    stride=1,
                    padding=dilations[i],
                    dilation=dilations[i],
                    dw_norm_cfg=norm_cfg,
                    dw_act_cfg=None,
                    pw_norm_cfg=norm_cfg,
                    pw_act_cfg=act_cfg))
            self.dilation_layers.append(dilation_layer)

    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels), (
            'Length of inputs must be the same as self.in_channels!')

        feats = [
            self.conv_layers[i - self.start_level](inputs[i])
            for i in range(self.start_level, self.backbone_end_level)
        ]

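        # Upsample all deeper levels to the spatial size of the shallowest
        # selected level.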
        h, w = feats[0].shape[2:]
        for i in range(1, len(feats)):
            feats[i] = resize(
                feats[i],
                size=(h, w),
                mode='bilinear',
                align_corners=self.align_corners)

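        # Fuse: concatenate the per-level features, run the parallel dilated
        # branches on the result, and concatenate the branch outputs.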
        feat = torch.cat(feats, dim=1)
        concat_feat = torch.cat(
            [layer(feat) for layer in self.dilation_layers], dim=1)

        outs = []

        # With the default settings, outs[2] is the JPU output consumed by
        # the decode head, while outs[1] (and optionally outs[0]) are raw
        # backbone feature maps for auxiliary heads.
        for i in range(self.start_level, self.backbone_end_level - 1):
            outs.append(inputs[i])
        outs.append(concat_feat)
        return tuple(outs)