# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy

import numpy as np
import torch
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.model import BaseModule
from torch import Tensor

from .utils import expand_rates, get_single_padding


class BaseConvRFSearchOp(BaseModule):
    """Base class of ConvRFSearchOp.

    Args:
        op_layer (nn.Module): PyTorch module, e.g., ``nn.Conv2d``.
        global_config (dict): Config dict.
    """

    def __init__(self, op_layer: nn.Module, global_config: dict):
        super().__init__()
        self.op_layer = op_layer
        self.global_config = global_config

    def normalize(self, weights: nn.Parameter) -> nn.Parameter:
        """Normalize weights so they are non-negative and sum to 1.

        Args:
            weights (nn.Parameter): Weights to be normalized.

        Returns:
            nn.Parameter: Normalized weights.
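
        Example:
            A quick sketch of the behavior (``op`` is any
            ``BaseConvRFSearchOp`` instance):

            >>> op.normalize(torch.tensor([1.0, -1.0, 2.0]))
            tensor([0.2500, 0.2500, 0.5000])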
""" | |
abs_weights = torch.abs(weights) | |
normalized_weights = abs_weights / torch.sum(abs_weights) | |
return normalized_weights | |


class Conv2dRFSearchOp(BaseConvRFSearchOp):
    """Enable Conv2d with receptive field searching ability.

    Args:
        op_layer (nn.Module): PyTorch module, e.g., ``nn.Conv2d``.
        global_config (dict): Config dict. It must include:

            - "init_alphas": The value for initializing the weight of
              each branch.
            - "num_branches": The controller of the size of the search
              space (the number of branches).
            - "exp_rate": The controller of the sparsity of the search
              space.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.

            Extra keys may exist, but they are used by RFSearchHook,
            e.g., "step", "max_step", "search_interval", and "skip_layer".
        verbose (bool): Whether to print rf-next related logging messages.
            Defaults to True.
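
    Example:
        A minimal construction sketch; the config values below are
        illustrative, not recommended settings.

        >>> import torch.nn as nn
        >>> conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        >>> config = dict(
        ...     init_alphas=0.01, num_branches=2, exp_rate=0.5,
        ...     mmin=1, mmax=24)
        >>> op = Conv2dRFSearchOp(conv, config, verbose=False)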
""" | |

    def __init__(self,
                 op_layer: nn.Module,
                 global_config: dict,
                 verbose: bool = True):
        super().__init__(op_layer, global_config)
        assert global_config is not None, 'global_config is None'
        self.num_branches = global_config['num_branches']
        assert self.num_branches in [2, 3]
        self.verbose = verbose
        init_dilation = op_layer.dilation
        self.dilation_rates = expand_rates(init_dilation, global_config)
        # Keep the original dilation along any axis whose kernel size is 1
        # or even; only odd kernel sizes larger than 1 are searched.
        if self.op_layer.kernel_size[0] == 1 or \
                self.op_layer.kernel_size[0] % 2 == 0:
            self.dilation_rates = [(op_layer.dilation[0], r[1])
                                   for r in self.dilation_rates]
        if self.op_layer.kernel_size[1] == 1 or \
                self.op_layer.kernel_size[1] % 2 == 0:
            self.dilation_rates = [(r[0], op_layer.dilation[1])
                                   for r in self.dilation_rates]

        self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights, global_config['init_alphas'])

    def forward(self, input: Tensor) -> Tensor:
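        """Forward as a weighted sum over all candidate dilation branches.

        Conceptually ``output = sum_i w_i * conv(input, dilation=r_i)``,
        where ``w_i`` are the normalized branch weights.
        """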
        norm_w = self.normalize(
            self.branch_weights[:len(self.dilation_rates)])
        if len(self.dilation_rates) == 1:
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(self.dilation_rates[0]),
                    dilation=self.dilation_rates[0],
                    groups=self.op_layer.groups,
                )
            ]
        else:
            # One convolution branch per candidate dilation, each scaled
            # by its normalized architecture weight.
            outputs = [
                nn.functional.conv2d(
                    input,
                    weight=self.op_layer.weight,
                    bias=self.op_layer.bias,
                    stride=self.op_layer.stride,
                    padding=self.get_padding(r),
                    dilation=r,
                    groups=self.op_layer.groups,
                ) * norm_w[i] for i, r in enumerate(self.dilation_rates)
            ]
        output = outputs[0]
        for i in range(1, len(self.dilation_rates)):
            output += outputs[i]
        return output

    def estimate_rates(self) -> None:
        """Estimate the new dilation rate based on trained branch_weights."""
        norm_w = self.normalize(
            self.branch_weights[:len(self.dilation_rates)])
        if self.verbose:
            print_log(
                'Estimate dilation {} with weight {}.'.format(
                    self.dilation_rates,
                    norm_w.detach().cpu().numpy().tolist()), 'current')

        sum0, sum1, w_sum = 0, 0, 0
        for i in range(len(self.dilation_rates)):
            sum0 += norm_w[i].item() * self.dilation_rates[i][0]
            sum1 += norm_w[i].item() * self.dilation_rates[i][1]
            w_sum += norm_w[i].item()
        estimated = [
            np.clip(
                int(round(sum0 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item(),
            np.clip(
                int(round(sum1 / w_sum)), self.global_config['mmin'],
                self.global_config['mmax']).item()
        ]
        self.op_layer.dilation = tuple(estimated)
        self.op_layer.padding = self.get_padding(self.op_layer.dilation)
        self.dilation_rates = [tuple(estimated)]
        if self.verbose:
            print_log(f'Estimate as {tuple(estimated)}', 'current')

    def expand_rates(self) -> None:
        """Expand the current dilation into a new set of candidate rates."""
        dilation = self.op_layer.dilation
        dilation_rates = expand_rates(dilation, self.global_config)
        if self.op_layer.kernel_size[0] == 1 or \
                self.op_layer.kernel_size[0] % 2 == 0:
            dilation_rates = [(dilation[0], r[1]) for r in dilation_rates]
        if self.op_layer.kernel_size[1] == 1 or \
                self.op_layer.kernel_size[1] % 2 == 0:
            dilation_rates = [(r[0], dilation[1]) for r in dilation_rates]

        self.dilation_rates = copy.deepcopy(dilation_rates)
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights,
                          self.global_config['init_alphas'])

    def get_padding(self, dilation: tuple) -> tuple:
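        """Compute per-axis padding for the given dilation.

        Assuming ``get_single_padding`` follows the usual size-preserving
        rule, ``padding = ((stride - 1) + dilation * (kernel - 1)) // 2``,
        a 3x3 kernel with stride 1 and dilation (2, 2) yields
        padding (2, 2).
        """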
        padding = (get_single_padding(self.op_layer.kernel_size[0],
                                      self.op_layer.stride[0], dilation[0]),
                   get_single_padding(self.op_layer.kernel_size[1],
                                      self.op_layer.stride[1], dilation[1]))
        return padding
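

# Minimal usage sketch. Assumes `expand_rates` and `get_single_padding`
# from `.utils` behave as in mmcv: wrap a Conv2d, run one weighted forward
# pass over the dilation branches, collapse them to a single estimated
# dilation, then re-expand the search space around it.
if __name__ == '__main__':
    config = dict(
        init_alphas=0.01, num_branches=2, exp_rate=0.5, mmin=1, mmax=24)
    conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)
    op = Conv2dRFSearchOp(conv, config, verbose=False)

    x = torch.randn(1, 8, 32, 32)
    out = op(x)          # weighted sum over all candidate dilations
    op.estimate_rates()  # collapse branches into one dilation estimate
    op.expand_rates()    # re-expand the search space around the estimate
    print(out.shape, op.op_layer.dilation)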