# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from operator import attrgetter
from typing import List, Union

import torch
import torch.nn as nn

def efficient_conv_bn_eval_forward(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """Code borrowed from mmcv 2.0.1, so that this feature can be used with
    older mmcv versions.

    Implementation based on https://arxiv.org/abs/2305.11624
    "Tune-Mode ConvBN Blocks For Efficient Transfer Learning"

    It leverages the associative law between convolution and affine transform,
    i.e., normalize(weight conv feature) = (normalize weight) conv feature.
    It works for the eval mode of ConvBN blocks during validation, and can be
    used for training as well. It reduces memory and computation cost.

    Args:
        bn (_BatchNorm): a BatchNorm module.
        conv (_ConvNd): a conv module.
        x (torch.Tensor): input feature map.
    """
    # These lines of code are designed to deal with various cases,
    # like bn without affine transform and conv without bias.
    weight_on_the_fly = conv.weight
    if conv.bias is not None:
        bias_on_the_fly = conv.bias
    else:
        bias_on_the_fly = torch.zeros_like(bn.running_var)

    if bn.weight is not None:
        bn_weight = bn.weight
    else:
        bn_weight = torch.ones_like(bn.running_var)

    if bn.bias is not None:
        bn_bias = bn.bias
    else:
        bn_bias = torch.zeros_like(bn.running_var)

    # shape of [C_out, 1, 1, 1] in Conv2d
    weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(
        [-1] + [1] * (len(conv.weight.shape) - 1))
    # shape of [C_out, 1, 1, 1] in Conv2d
    coeff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff

    # shape of [C_out, C_in, k, k] in Conv2d
    weight_on_the_fly = weight_on_the_fly * coeff_on_the_fly
    # shape of [C_out] in Conv2d
    bias_on_the_fly = bn_bias + coeff_on_the_fly.flatten() * (
        bias_on_the_fly - bn.running_mean)

    return conv._conv_forward(x, weight_on_the_fly, bias_on_the_fly)
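

# A minimal sanity-check sketch (an illustrative addition, not part of the
# original mmcv code): assuming a plain Conv2d followed by BatchNorm2d in eval
# mode with randomized running statistics, the on-the-fly fused forward above
# should numerically match the unfused conv -> bn computation.
def _check_efficient_conv_bn_eval_forward():
    conv = nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
    bn = nn.BatchNorm2d(8).eval()
    # randomize the bn statistics so the check is non-trivial
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-1.0, 1.0)
    x = torch.randn(2, 3, 16, 16)
    with torch.no_grad():
        ref = bn(conv(x))
        fused = efficient_conv_bn_eval_forward(bn, conv, x)
    assert torch.allclose(ref, fused, atol=1e-5)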


def efficient_conv_bn_eval_control(bn: nn.modules.batchnorm._BatchNorm,
                                   conv: nn.modules.conv._ConvNd,
                                   x: torch.Tensor):
    """Controls whether to use `efficient_conv_bn_eval_forward`.

    If the following `bn` is in `eval` mode, we turn on the special
    `efficient_conv_bn_eval_forward`; otherwise we fall back to the ordinary
    conv followed by bn.
    """
    if not bn.training:
        # bn in eval mode
        output = efficient_conv_bn_eval_forward(bn, conv, x)
        return output
    else:
        conv_out = conv._conv_forward(x, conv.weight, conv.bias)
        return bn(conv_out)


def efficient_conv_bn_eval_graph_transform(fx_model):
    """Find consecutive conv + bn calls in the graph and modify the graph in
    place to use the fused operation."""
    modules = dict(fx_model.named_modules())
    patterns = [(torch.nn.modules.conv._ConvNd,
                 torch.nn.modules.batchnorm._BatchNorm)]
    pairs = []
    # iterate through the nodes in the graph to find ConvBN blocks
    for node in fx_model.graph.nodes:
        # if the current node isn't calling a module, we can ignore it
        if node.op != 'call_module':
            continue
        target_module = modules[node.target]
        found_pair = False
        for conv_class, bn_class in patterns:
            if isinstance(target_module, bn_class):
                input_node = node.args[0]
                # the input of bn must itself be a module call; otherwise
                # looking its target up in `modules` would raise a KeyError
                if input_node.op != 'call_module':
                    continue
                source_module = modules[input_node.target]
                if isinstance(source_module, conv_class):
                    found_pair = True
        # not a conv-bn pattern, or the output of conv is used by other nodes
        if not found_pair or len(node.args[0].users) > 1:
            continue
        # found a pair of conv and bn computation nodes to optimize
        conv_node = node.args[0]
        bn_node = node
        pairs.append([conv_node, bn_node])
    for conv_node, bn_node in pairs:
        # set the insertion point
        with fx_model.graph.inserting_before(conv_node):
            # create `get_attr` nodes to access the modules.
            # note that we directly call `create_node` to fill the `name`
            # argument. `fx_model.graph.get_attr` and
            # `fx_model.graph.call_function` do not allow the `name` argument.
            conv_get_node = fx_model.graph.create_node(
                op='get_attr', target=conv_node.target, name='get_conv')
            bn_get_node = fx_model.graph.create_node(
                op='get_attr', target=bn_node.target, name='get_bn')
            # prepare args for the fused function
            args = (bn_get_node, conv_get_node, conv_node.args[0])
            # create a new node
            new_node = fx_model.graph.create_node(
                op='call_function',
                target=efficient_conv_bn_eval_control,
                args=args,
                name='efficient_conv_bn_eval')
        # this node replaces the original conv + bn, and therefore
        # should replace the uses of bn_node
        bn_node.replace_all_uses_with(new_node)
        # take care of the deletion order:
        # delete bn_node first, and then conv_node
        fx_model.graph.erase_node(bn_node)
        fx_model.graph.erase_node(conv_node)

    # regenerate the code
    fx_model.graph.lint()
    fx_model.recompile()
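

# An illustrative sketch (an addition, not part of the original file): tracing
# a toy Conv2d -> BatchNorm2d module and printing its graph before and after
# the transform should show the two `call_module` nodes replaced by a single
# `call_function` node targeting `efficient_conv_bn_eval_control`.
def _show_graph_transform():
    import torch.fx as fx

    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    fx_model = fx.symbolic_trace(toy)
    print(fx_model.graph)  # conv -> bn as two call_module nodes
    efficient_conv_bn_eval_graph_transform(fx_model)
    print(fx_model.graph)  # one call_function(efficient_conv_bn_eval_control)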


def turn_on_efficient_conv_bn_eval_for_single_model(model: torch.nn.Module):
    """Trace `model` with torch.fx and fuse its conv + bn pairs."""
    import torch.fx as fx

    # currently we use `fx.symbolic_trace` to trace models.
    # in the future, we might turn to the pytorch 2.0 compile infrastructure
    # to get the `fx.GraphModule` IR. Nonetheless, the graph transform
    # function can remain unchanged; we just need to change the way
    # we get the `fx.GraphModule`.
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    efficient_conv_bn_eval_graph_transform(fx_model)
    model.forward = fx_model.forward


def turn_on_efficient_conv_bn_eval(model: torch.nn.Module,
                                   modules: Union[List[str], str]):
    """Turn on the efficient conv-bn eval path for the named submodules of
    `model`. `modules` is a dotted attribute path or a list of such paths."""
    if isinstance(modules, str):
        modules = [modules]
    for module_name in modules:
        module = attrgetter(module_name)(model)
        turn_on_efficient_conv_bn_eval_for_single_model(module)
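

# A usage sketch (an illustrative addition, not from the original file): it
# builds a hypothetical toy model whose `backbone` submodule holds a conv + bn
# pair, turns on the efficient eval path for that submodule only, and runs an
# eval-mode forward pass.
if __name__ == '__main__':

    class _ToyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3),
                                          nn.BatchNorm2d(8))
            self.head = nn.Conv2d(8, 4, 1)

        def forward(self, x):
            return self.head(self.backbone(x))

    model = _ToyModel()
    # fuse conv + bn inside `model.backbone` only
    turn_on_efficient_conv_bn_eval(model, 'backbone')
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 32, 32))
    print(out.shape)  # torch.Size([1, 4, 30, 30])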