# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from operator import attrgetter
from typing import List, Union
import torch
import torch.nn as nn
def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """Code borrowed from mmcv 2.0.1, so that this feature can be used for old
    mmcv versions.
    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"
    It leverages the associative law between convolution and affine transform,
    i.e., normalize (weight conv feature) = (normalize weight) conv feature.
    It works for Eval mode of ConvBN blocks during validation, and can be used
    for training as well. It reduces memory and computation cost.
    Args:
        bn (_BatchNorm): a BatchNorm module.
        conv (nn._ConvNd): a conv module
        x (torch.Tensor): Input feature map.
    """
    # Substitute neutral values for the optional affine/bias parameters so the
    # fused formula below covers conv-without-bias and bn-without-affine alike.
    conv_bias = (conv.bias if conv.bias is not None else
                 torch.zeros_like(bn.running_var))
    gamma = (bn.weight if bn.weight is not None else
             torch.ones_like(bn.running_var))
    beta = (bn.bias if bn.bias is not None else
            torch.zeros_like(bn.running_var))
    # Inverse running std, reshaped to broadcast over every dim except the
    # output-channel dim -- e.g. [C_out, 1, 1, 1] for Conv2d.
    broadcast_shape = [-1] + [1] * (conv.weight.dim() - 1)
    inv_std = torch.rsqrt(bn.running_var + bn.eps).reshape(broadcast_shape)
    # Per-output-channel scale folding bn's gamma into the normalization.
    scale = gamma.view_as(inv_std) * inv_std
    # Fold the scale into conv's weight ([C_out, C_in, k, k] in Conv2d) and
    # bias ([C_out]) so a single convolution realizes conv followed by bn.
    fused_weight = conv.weight * scale
    fused_bias = beta + scale.flatten() * (conv_bias - bn.running_mean)
    return conv._conv_forward(x, fused_weight, fused_bias)
def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """This function controls whether to use `efficient_conv_bn_eval_forward`.
    If the following `bn` is in `eval` mode, then we turn on the special
    `efficient_conv_bn_eval_forward`.
    """
    if bn.training:
        # bn uses batch statistics: run the plain conv -> bn sequence.
        return bn(conv._conv_forward(x, conv.weight, conv.bias))
    # bn in eval mode: use the fused, memory-efficient computation.
    return efficient_conv_bn_eval_forward(bn, conv, x)
def efficient_conv_bn_eval_graph_transform(fx_model):
    """Find consecutive conv+bn calls in the graph, inplace modify the graph
    with the fused operation.

    Args:
        fx_model: a ``torch.fx.GraphModule`` whose graph is rewritten in
            place; its code is regenerated via ``recompile`` at the end.
    """
    import torch.fx
    modules = dict(fx_model.named_modules())
    patterns = [(torch.nn.modules.conv._ConvNd,
                 torch.nn.modules.batchnorm._BatchNorm)]
    pairs = []
    # Iterate through nodes in the graph to find ConvBN blocks
    for node in fx_model.graph.nodes:
        # If our current node isn't calling a Module then we can ignore it.
        if node.op != 'call_module':
            continue
        target_module = modules[node.target]
        input_node = node.args[0] if node.args else None
        # The producer must itself be a `call_module` node; a placeholder,
        # function output, or non-node argument has no entry in `modules`
        # (the unguarded `modules[node.args[0].target]` lookup used to raise
        # KeyError for e.g. a BatchNorm applied directly to a model input).
        if (not isinstance(input_node, torch.fx.Node)
                or input_node.op != 'call_module'):
            continue
        found_pair = False
        for conv_class, bn_class in patterns:
            if isinstance(target_module, bn_class):
                source_module = modules[input_node.target]
                if isinstance(source_module, conv_class):
                    found_pair = True
        # Not a conv-BN pattern or output of conv is used by other nodes
        if not found_pair or len(input_node.users) > 1:
            continue
        # Find a pair of conv and bn computation nodes to optimize
        pairs.append([input_node, node])
    for conv_node, bn_node in pairs:
        # set insertion point
        fx_model.graph.inserting_before(conv_node)
        # create `get_attr` node to access modules
        # note that we directly call `create_node` to fill the `name`
        # argument. `fx_model.graph.get_attr` and
        # `fx_model.graph.call_function` does not allow the `name` argument.
        conv_get_node = fx_model.graph.create_node(
            op='get_attr', target=conv_node.target, name='get_conv')
        bn_get_node = fx_model.graph.create_node(
            op='get_attr', target=bn_node.target, name='get_bn')
        # prepare args for the fused function
        args = (bn_get_node, conv_get_node, conv_node.args[0])
        # create a new node
        new_node = fx_model.graph.create_node(
            op='call_function',
            target=efficient_conv_bn_eval_control,
            args=args,
            name='efficient_conv_bn_eval')
        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node
        fx_model.graph.erase_node(bn_node)
        fx_model.graph.erase_node(conv_node)
    # regenerate the code
    fx_model.graph.lint()
    fx_model.recompile()
def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
    """Trace `model` with torch.fx, fuse its conv+bn pairs, and patch its
    `forward` to run the transformed graph."""
    import torch.fx as fx

    # currently we use `fx.symbolic_trace` to trace models.
    # in the future, we might turn to pytorch 2.0 compile infrastructure to
    # get the `fx.GraphModule` IR. Nonetheless, the graph transform function
    # can remain unchanged. We just need to change the way
    # we get `fx.GraphModule`.
    graph_module: fx.GraphModule = fx.symbolic_trace(model)
    efficient_conv_bn_eval_graph_transform(graph_module)
    model.forward = graph_module.forward
def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
                                   modules: Union[List[str], str]):
    """Enable the efficient conv-bn-eval transform on one or more submodules
    of `model`, addressed by dotted attribute path(s)."""
    # Accept a single dotted path as a convenience for the common case.
    names = [modules] if isinstance(modules, str) else modules
    for name in names:
        submodule = attrgetter(name)(model)
        turn_on_efficient_conv_bn_eval_for_single_model(submodule)