File size: 5,804 Bytes
2c924d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d
from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version


@CONV_LAYERS.register_module(name='SAC')
class SAConv2d(ConvAWS2d):
    """SAC (Switchable Atrous Convolution)

    This is an implementation of SAC in DetectoRS
    (https://arxiv.org/pdf/2006.02334.pdf).

    The forward pass runs the convolution twice on the same input: once
    with the configured ``padding``/``dilation`` and ``weight``, and once
    with ``padding``/``dilation`` multiplied by 3 and
    ``weight + weight_diff``. The two results are blended with a
    single-channel switch map predicted from the input. A globally pooled
    context term is added to the input before, and to the output after,
    the switchable convolution.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        use_deform: If ``True``, replace convolution with deformable
            convolution. Default: ``False``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 use_deform=False):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.use_deform = use_deform
        # 1x1 conv producing the single-channel switch map; it uses the same
        # stride as the main conv so it can be multiplied with the outputs.
        self.switch = nn.Conv2d(
            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
        # Additive delta applied to the weight of the large-dilation branch.
        # Allocated uninitialised here and zeroed in ``init_weights``.
        self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
        self.pre_context = nn.Conv2d(
            self.in_channels, self.in_channels, kernel_size=1, bias=True)
        self.post_context = nn.Conv2d(
            self.out_channels, self.out_channels, kernel_size=1, bias=True)
        if self.use_deform:
            # NOTE(review): 18 output channels presumably corresponds to an
            # (x, y) offset for each tap of a 3x3 kernel (2 * 3 * 3) with a
            # single deform group -- confirm before using other kernel sizes.
            self.offset_s = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
            self.offset_l = nn.Conv2d(
                self.in_channels,
                18,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=True)
        self.init_weights()

    def init_weights(self):
        """Initialise the SAC-specific sub-modules with constants.

        Assuming ``constant_init`` fills the weight with the given value
        and the bias with ``bias``, the switch starts at a constant 1 (so
        the output initially equals the small-dilation branch), while
        ``weight_diff``, the context convs and the offset convs start at 0.
        """
        constant_init(self.switch, 0, bias=1)
        self.weight_diff.data.zero_()
        constant_init(self.pre_context, 0)
        constant_init(self.post_context, 0)
        if self.use_deform:
            constant_init(self.offset_s, 0)
            constant_init(self.offset_l, 0)

    def _conv_compat(self, x, weight, zero_bias):
        """Call the parent convolution with an explicit ``weight``.

        Dispatches on the installed torch version because the name and
        signature of the parent's conv helper changed between releases.
        The parent helper reads ``self.padding`` / ``self.dilation``, so
        the caller controls the dilation rate through those attributes.

        Args:
            x (Tensor): Input feature map.
            weight (Tensor): Convolution weight to use for this call.
            zero_bias (Tensor): All-zero bias of length ``out_channels``.
                Only passed on torch >= 1.8.0, where the bias argument
                is required.

        Returns:
            Tensor: Result of the convolution.
        """
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
            return super().conv2d_forward(x, weight)
        if digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
            # bias is a required argument of _conv_forward in torch 1.8.0
            # NOTE(review): this path passes zeros rather than ``self.bias``;
            # whether the other paths apply ``self.bias`` depends on the
            # parent implementation -- confirm if ``bias=True`` matters.
            return super()._conv_forward(x, weight, zero_bias)
        return super()._conv_forward(x, weight)

    def forward(self, x):
        """Apply pre-context, the switchable atrous conv, and post-context.

        Args:
            x (Tensor): Input feature map.

        Returns:
            Tensor: ``switch * out_s + (1 - switch) * out_l`` plus the
            post-context term, where ``out_s`` / ``out_l`` are the
            small- and large-dilation branch outputs.
        """
        # pre-context: add a globally pooled, 1x1-projected summary to x
        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
        avg_x = self.pre_context(avg_x)
        avg_x = avg_x.expand_as(x)
        x = x + avg_x
        # switch: 5x5 average (reflect-padded so the size is kept) then 1x1
        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
        switch = self.switch(avg_x)
        # sac
        weight = self._get_weight(self.weight)
        zero_bias = torch.zeros(
            self.out_channels, device=weight.device, dtype=weight.dtype)

        # small-dilation branch: configured padding/dilation
        if self.use_deform:
            offset = self.offset_s(avg_x)
            out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                  self.dilation, self.groups, 1)
        else:
            out_s = self._conv_compat(x, weight, zero_bias)

        # large-dilation branch: padding and dilation are tripled on the
        # module for the duration of the call because the parent conv helper
        # reads them from ``self``.
        ori_p = self.padding
        ori_d = self.dilation
        self.padding = tuple(3 * p for p in ori_p)
        self.dilation = tuple(3 * d for d in ori_d)
        try:
            weight = weight + self.weight_diff
            if self.use_deform:
                offset = self.offset_l(avg_x)
                out_l = deform_conv2d(x, offset, weight, self.stride,
                                      self.padding, self.dilation,
                                      self.groups, 1)
            else:
                out_l = self._conv_compat(x, weight, zero_bias)
        finally:
            # Always restore the configured values, even if the conv raised;
            # otherwise the module would keep the tripled settings and every
            # later forward pass would be silently wrong.
            self.padding = ori_p
            self.dilation = ori_d

        out = switch * out_s + (1 - switch) * out_l
        # post-context: same global-summary trick applied to the output
        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
        avg_x = self.post_context(avg_x)
        avg_x = avg_x.expand_as(out)
        out = out + avg_x
        return out