# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy

import numpy as np
import torch
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.model import BaseModule
from torch import Tensor

from .utils import expand_rates, get_single_padding


class BaseConvRFSearchOp(BaseModule):
    """Based class of ConvRFSearchOp.

    Args:
        op_layer (nn.Module): PyTorch module, e.g., Conv2d.
        global_config (dict): config dict.
    """

    def __init__(self, op_layer: nn.Module, global_config: dict):
        super().__init__()
        self.op_layer = op_layer
        self.global_config = global_config

    def normlize(self, weights: nn.Parameter) -> nn.Parameter:
        """Normalize weights.

        Args:
            weights (nn.Parameter): Weights to be normalized.

        Returns:
            nn.Parameter: Normalized weights.
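
        Example:
            A sketch with illustrative values only:

            >>> w = nn.Parameter(torch.tensor([1., -1., 2.]))
            >>> # |w| = [1., 1., 2.] sums to 4., so normlize(w)
            >>> # yields [0.25, 0.25, 0.5].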
        """
        abs_weights = torch.abs(weights)
        normalized_weights = abs_weights / torch.sum(abs_weights)
        return normalized_weights


class Conv2dRFSearchOp(BaseConvRFSearchOp):
    """Enable Conv2d with receptive field searching ability.

    Args:
        op_layer (nn.Module): PyTorch module, e.g., Conv2d.
        global_config (dict): config dict. Must not be None and must
            include the following keys:

            - "init_alphas": The value for initializing weights of each branch.
            - "num_branches": The controller of the size of
              search space (the number of branches).
            - "exp_rate": The controller of the sparsity of search space.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.

            Extra keys may exist; they are consumed by RFSearchHook rather
            than by this class, e.g., "step", "max_step", "search_interval",
            and "skip_layer".
        verbose (bool): Determines whether to print rf-next
            related logging messages.
            Defaults to True.
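
    Example:
        A minimal usage sketch; the config values below are illustrative
        placeholders, not verified defaults:

        >>> conv = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        >>> config = dict(init_alphas=0.01, num_branches=2,
        ...               exp_rate=0.5, mmin=1, mmax=24)
        >>> op = Conv2dRFSearchOp(conv, config)
        >>> out = op(torch.rand(1, 16, 32, 32))  # same spatial size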
    """

    def __init__(self,
                 op_layer: nn.Module,
                 global_config: dict,
                 verbose: bool = True):
        super().__init__(op_layer, global_config)
        assert global_config is not None, 'global_config is None'
        self.num_branches = global_config['num_branches']
        assert self.num_branches in [2, 3]
        self.verbose = verbose
        init_dilation = op_layer.dilation
        self.dilation_rates = expand_rates(init_dilation, global_config)
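        # Keep the original dilation along any dimension whose kernel size
        # is 1 or even; only odd kernel sizes larger than 1 are searched.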
        if (self.op_layer.kernel_size[0] == 1
                or self.op_layer.kernel_size[0] % 2 == 0):
            self.dilation_rates = [(op_layer.dilation[0], r[1])
                                   for r in self.dilation_rates]
        if (self.op_layer.kernel_size[1] == 1
                or self.op_layer.kernel_size[1] % 2 == 0):
            self.dilation_rates = [(r[0], op_layer.dilation[1])
                                   for r in self.dilation_rates]

        self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights, global_config['init_alphas'])

    def forward(self, input: Tensor) -> Tensor:
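        """Apply each candidate dilation branch to ``input`` and sum the
        outputs, weighted by the normalized branch weights; a single
        remaining branch is applied without weighting."""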
        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        if len(self.dilation_rates) == 1:
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(self.dilation_rates[0]),
                    dilation=self.dilation_rates[0],
                    groups=self.op_layer.groups,
                )
            ]
        else:
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(r),
                    dilation=r,
                    groups=self.op_layer.groups,
                ) * norm_w[i] for i, r in enumerate(self.dilation_rates)
            ]
        output = outputs[0]
        for i in range(1, len(self.dilation_rates)):
            output += outputs[i]
        return output

    def estimate_rates(self) -> None:
        """Estimate new dilation rate based on trained branch_weights."""
        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        if self.verbose:
            print_log(
                'Estimate dilation {} with weight {}.'.format(
                    self.dilation_rates,
                    norm_w.detach().cpu().numpy().tolist()), 'current')

        sum0, sum1, w_sum = 0, 0, 0
        for i in range(len(self.dilation_rates)):
            sum0 += norm_w[i].item() * self.dilation_rates[i][0]
            sum1 += norm_w[i].item() * self.dilation_rates[i][1]
            w_sum += norm_w[i].item()
        estimated = [
            np.clip(
                int(round(sum0 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item(),
            np.clip(
                int(round(sum1 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item()
        ]
        self.op_layer.dilation = tuple(estimated)
        self.op_layer.padding = self.get_padding(self.op_layer.dilation)
        self.dilation_rates = [tuple(estimated)]
        if self.verbose:
            print_log(f'Estimate as {tuple(estimated)}', 'current')

    def expand_rates(self) -> None:
        """Expand dilation rate."""
        dilation = self.op_layer.dilation
        dilation_rates = expand_rates(dilation, self.global_config)
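        # Apply the same kernel-size rule as in ``__init__``: dimensions
        # with kernel size 1 or an even kernel size are not searched.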
        if (self.op_layer.kernel_size[0] == 1
                or self.op_layer.kernel_size[0] % 2 == 0):
            dilation_rates = [(dilation[0], r[1]) for r in dilation_rates]
        if (self.op_layer.kernel_size[1] == 1
                or self.op_layer.kernel_size[1] % 2 == 0):
            dilation_rates = [(r[0], dilation[1]) for r in dilation_rates]

        self.dilation_rates = copy.deepcopy(dilation_rates)
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights,
                          self.global_config['init_alphas'])

    def get_padding(self, dilation) -> tuple:
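        """Compute the (height, width) padding for the effective kernel
        size implied by the given dilation."""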
        padding = (get_single_padding(self.op_layer.kernel_size[0],
                                      self.op_layer.stride[0], dilation[0]),
                   get_single_padding(self.op_layer.kernel_size[1],
                                      self.op_layer.stride[1], dilation[1]))
        return padding