File size: 6,447 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from operator import attrgetter
from typing import List, Union

import torch
import torch.nn as nn


def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """Compute ``bn(conv(x))`` as one convolution with BN folded into it.

    Code borrowed from mmcv 2.0.1, so that this feature can be used for old
    mmcv versions.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning".
    Convolution and the per-channel affine transform applied by a BatchNorm
    that uses its running statistics are both linear, so they associate:
    normalize(weight conv feature) == (normalized weight) conv feature.
    The BN statistics and affine parameters are therefore folded into the
    conv weight and bias on the fly. This reduces memory and computation
    cost during validation and can be used for training as well.

    Args:
        bn (_BatchNorm): a BatchNorm module; its ``running_mean``,
            ``running_var`` and ``eps`` are used for normalization.
        conv (nn._ConvNd): a conv module providing ``_conv_forward``.
        x (torch.Tensor): Input feature map.

    Returns:
        torch.Tensor: the conv output computed with the folded weight and
        bias.
    """
    running_var = bn.running_var

    # Substitute neutral values for a conv without bias and for a BN without
    # affine parameters, so a single formula covers every combination.
    conv_bias = (conv.bias
                 if conv.bias is not None else torch.zeros_like(running_var))
    gamma = (bn.weight
             if bn.weight is not None else torch.ones_like(running_var))
    beta = bn.bias if bn.bias is not None else torch.zeros_like(running_var)

    # Per-output-channel scale that broadcasts over the conv weight,
    # e.g. [C_out, 1, 1, 1] for Conv2d.
    broadcast_shape = [-1] + [1] * (len(conv.weight.shape) - 1)
    inv_std = torch.rsqrt(running_var + bn.eps).reshape(broadcast_shape)
    scale = gamma.view_as(inv_std) * inv_std

    # Folded weight, e.g. [C_out, C_in, k, k] for Conv2d.
    fused_weight = conv.weight * scale
    # Folded bias, shape [C_out].
    fused_bias = beta + scale.flatten() * (conv_bias - bn.running_mean)

    return conv._conv_forward(x, fused_weight, fused_bias)


def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """This function controls whether to use `efficient_conv_bn_eval_forward`.

    The fused `efficient_conv_bn_eval_forward` is used only when `bn`
    normalizes with its running statistics. That requires `bn` to be in
    `eval` mode AND to actually have `running_mean` / `running_var`
    buffers. A BatchNorm created with `track_running_stats=False` keeps no
    running buffers and normalizes with batch statistics even in `eval`
    mode; that cannot be folded into the conv weights. In every other case
    the conv and the bn are run one after the other.

    Args:
        bn (_BatchNorm): the BatchNorm module following the conv.
        conv (nn._ConvNd): a conv module providing ``_conv_forward``.
        x (torch.Tensor): Input feature map.

    Returns:
        torch.Tensor: ``bn(conv(x))``, computed either fused or unfused.
    """
    uses_running_stats = (not bn.training and bn.running_mean is not None
                          and bn.running_var is not None)
    if uses_running_stats:
        return efficient_conv_bn_eval_forward(bn, conv, x)
    # Training mode, or batch statistics in eval mode: no folding possible.
    conv_out = conv._conv_forward(x, conv.weight, conv.bias)
    return bn(conv_out)


def efficient_conv_bn_eval_graph_transform(fx_model):
    """Find consecutive conv+bn calls in the graph, inplace modify the graph
    with the fused operation.

    Each ``call_module`` node running a BatchNorm is a candidate. It is
    rewritten when its only positional input is a ``call_module`` conv node
    whose output is used by nothing else. The conv and bn nodes are then
    replaced by one ``call_function`` node running
    ``efficient_conv_bn_eval_control(bn, conv, x)``, and the module code is
    regenerated.

    Only convolutions that provide ``_conv_forward`` are matched
    (Conv1d/2d/3d and their subclasses). Transposed convolutions do not
    provide it and store their weight with output channels on dim 1, so
    they are left untouched.

    Args:
        fx_model (torch.fx.GraphModule): the traced model; its graph is
            modified in place and the module is recompiled.
    """
    import torch.fx as fx

    modules = dict(fx_model.named_modules())
    graph = fx_model.graph

    def _called_module(node):
        # The submodule run by a ``call_module`` node, else None
        # (placeholders, functions, methods, constants, ...).
        if not isinstance(node, fx.Node) or node.op != 'call_module':
            return None
        return modules.get(node.target)

    def _is_foldable_conv(module):
        # Transposed convs lack ``_conv_forward`` and keep output channels on
        # weight dim 1, so the fused forward cannot handle them.
        return (isinstance(module, nn.modules.conv._ConvNd)
                and not isinstance(module, nn.modules.conv._ConvTransposeNd)
                and hasattr(module, '_conv_forward'))

    pairs = []
    # Iterate through nodes in the graph to find ConvBN blocks
    for node in graph.nodes:
        if not isinstance(_called_module(node),
                          nn.modules.batchnorm._BatchNorm):
            continue
        # The bn must receive the conv output positionally.
        if not node.args:
            continue
        conv_node = node.args[0]
        if not _is_foldable_conv(_called_module(conv_node)):
            continue
        # The fused call forwards exactly one positional input to the conv.
        if len(conv_node.args) != 1 or conv_node.kwargs:
            continue
        # The conv output must not feed other nodes, since the conv node is
        # removed after fusion.
        if len(conv_node.users) > 1:
            continue
        pairs.append((conv_node, node))

    for conv_node, bn_node in pairs:
        # Build the replacement where the conv ran: its input is already
        # defined there and every user of the bn comes later.
        with graph.inserting_before(conv_node):
            # `create_node` is called directly to fill the `name` argument;
            # `graph.get_attr` and `graph.call_function` do not accept it.
            conv_get_node = graph.create_node(
                op='get_attr', target=conv_node.target, name='get_conv')
            bn_get_node = graph.create_node(
                op='get_attr', target=bn_node.target, name='get_bn')
            new_node = graph.create_node(
                op='call_function',
                target=efficient_conv_bn_eval_control,
                args=(bn_get_node, conv_get_node, conv_node.args[0]),
                name='efficient_conv_bn_eval')
        # The new node replaces conv + bn, so it takes over the bn's uses.
        bn_node.replace_all_uses_with(new_node)
        # bn_node consumes conv_node, so it has to be erased first.
        graph.erase_node(bn_node)
        graph.erase_node(conv_node)

    # regenerate the code
    graph.lint()
    fx_model.recompile()


def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
    """Replace ``model.forward`` with a traced version using fused conv+bn.

    The model is traced with ``torch.fx.symbolic_trace``, and the resulting
    graph is rewritten by ``efficient_conv_bn_eval_graph_transform``. The
    traced module's ``forward`` is then assigned to ``model.forward``.

    Args:
        model (torch.nn.Module): the module to patch in place; it must be
            traceable by ``torch.fx.symbolic_trace``.
    """
    from torch import fx

    # Symbolic tracing is just the current way of obtaining a GraphModule.
    # A PyTorch 2.0 compile frontend could supply the graph instead, and
    # the graph transform below would stay unchanged.
    traced = fx.symbolic_trace(model)
    efficient_conv_bn_eval_graph_transform(traced)
    model.forward = traced.forward


def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
                                   modules: Union[List[str], str]):
    """Enable the efficient conv+bn eval path on selected submodules.

    Args:
        model (torch.nn.Module): the root model.
        modules (list[str] | str): one name or a list of names of submodules
            of ``model``. Each name is resolved with ``operator.attrgetter``,
            so dotted paths are allowed. Each resolved submodule is patched
            by ``turn_on_efficient_conv_bn_eval_for_single_model``.
    """
    names = [modules] if isinstance(modules, str) else modules
    for name in names:
        submodule = attrgetter(name)(model)
        turn_on_efficient_conv_bn_eval_for_single_model(submodule)