# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy

import numpy as np
import torch
import torch.nn as nn
from mmengine.logging import print_log
from mmengine.model import BaseModule
from torch import Tensor

from .utils import expand_rates, get_single_padding


class BaseConvRFSearchOp(BaseModule):
    """Base class of ConvRFSearchOp.

    Args:
        op_layer (nn.Module): pytorch module, e.g., Conv2d
        global_config (dict): config dict.
    """

    def __init__(self, op_layer: nn.Module, global_config: dict):
        super().__init__()
        self.op_layer = op_layer
        self.global_config = global_config

    # NOTE: the misspelled method name is kept because external callers
    # may already depend on it.
    def normlize(self, weights: Tensor) -> Tensor:
        """Normalize weights by the sum of their absolute values.

        Each entry becomes ``|w_i| / sum_j |w_j|``, so the result is
        non-negative and sums to 1. If every entry is zero, the division is
        0 / 0 and the result is NaN.

        Args:
            weights (Tensor): Weights to be normalized (e.g. a slice of an
                ``nn.Parameter``).

        Returns:
            Tensor: Normalized weights.
        """
        abs_weights = torch.abs(weights)
        normalized_weights = abs_weights / torch.sum(abs_weights)
        return normalized_weights


class Conv2dRFSearchOp(BaseConvRFSearchOp):
    """Enable Conv2d with receptive field searching ability.

    The wrapped convolution is evaluated at one or more candidate dilation
    rates ("branches"). All branches share ``op_layer``'s weight, bias,
    stride and groups. With several branches, their outputs are mixed using
    the normalized learnable ``branch_weights``. :meth:`estimate_rates`
    collapses the branches into a single dilation rate, and
    :meth:`expand_rates` builds new candidate rates around the current one.

    Args:
        op_layer (nn.Module): pytorch module, e.g., Conv2d
        global_config (dict): config dict. Must not be None. It must include:

            - "init_alphas": The value for initializing weights of each
              branch.
            - "num_branches": The controller of the size of search space
              (the number of branches). Must be 2 or 3.
            - "exp_rate": The controller of the sparsity of search space.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.

            Extra keys may exist; they are used by RFSearchHook, e.g.,
            "step", "max_step", "search_interval", and "skip_layer".
        verbose (bool): Determines whether to print rf-next related logging
            messages. Defaults to True.
    """

    def __init__(self,
                 op_layer: nn.Module,
                 global_config: dict,
                 verbose: bool = True):
        super().__init__(op_layer, global_config)
        assert global_config is not None, 'global_config is None'
        self.num_branches = global_config['num_branches']
        assert self.num_branches in [2, 3]
        self.verbose = verbose
        self.dilation_rates = self._expand_dilation_rates(op_layer.dilation)
        # Sized for the maximum number of branches; only the first
        # len(self.dilation_rates) entries are used.
        self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches))
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights, global_config['init_alphas'])

    def _expand_dilation_rates(self, dilation: tuple) -> list:
        """Return candidate dilation rates built around ``dilation``.

        Along any axis whose kernel size is 1 or even, every candidate keeps
        ``dilation``'s value for that axis, so only the other axis varies.

        Args:
            dilation (tuple): The (axis 0, axis 1) dilation to expand from.

        Returns:
            list: Candidate ``(axis 0, axis 1)`` dilation rates.
        """
        # Bare ``expand_rates`` here is the helper imported from ``.utils``,
        # not the method of the same name on this class.
        rates = expand_rates(dilation, self.global_config)
        kernel_size = self.op_layer.kernel_size
        # Dilation has no effect on a size-1 kernel. Even sizes are presumably
        # excluded because the padding from get_single_padding cannot keep
        # them symmetric -- TODO confirm.
        if kernel_size[0] == 1 or kernel_size[0] % 2 == 0:
            rates = [(dilation[0], r[1]) for r in rates]
        if kernel_size[1] == 1 or kernel_size[1] % 2 == 0:
            rates = [(r[0], dilation[1]) for r in rates]
        return rates

    def _branch_conv(self, input: Tensor, dilation: tuple) -> Tensor:
        """Apply ``op_layer``'s conv with ``dilation`` and matching padding."""
        return nn.functional.conv2d(
            input,
            weight=self.op_layer.weight,
            bias=self.op_layer.bias,
            stride=self.op_layer.stride,
            padding=self.get_padding(dilation),
            dilation=dilation,
            groups=self.op_layer.groups,
        )

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` at every candidate dilation rate.

        With a single candidate, this is a plain convolution at that rate
        and ``branch_weights`` are not used. With several candidates, each
        branch output is scaled by its normalized branch weight and the
        scaled outputs are summed.

        Args:
            input (Tensor): Input to the wrapped convolution.

        Returns:
            Tensor: The (weighted) convolution output.
        """
        if len(self.dilation_rates) == 1:
            return self._branch_conv(input, self.dilation_rates[0])

        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        output = self._branch_conv(input, self.dilation_rates[0]) * norm_w[0]
        # Accumulate branch by branch rather than materializing every branch
        # output first. ``output`` is a fresh product tensor, so the in-place
        # add does not touch anything autograd saved.
        for i in range(1, len(self.dilation_rates)):
            output += self._branch_conv(input,
                                        self.dilation_rates[i]) * norm_w[i]
        return output

    def estimate_rates(self) -> None:
        """Estimate new dilation rate based on trained branch_weights.

        The new rate is the mean of the current candidate rates, weighted by
        the normalized branch weights. Each axis is rounded to an int and
        clipped to ``[global_config['mmin'], global_config['mmax']]``. The
        result becomes ``op_layer``'s dilation (with matching padding) and
        the only remaining candidate rate.
        """
        norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)])
        if self.verbose:
            print_log(
                'Estimate dilation {} with weight {}.'.format(
                    self.dilation_rates,
                    norm_w.detach().cpu().numpy().tolist()), 'current')

        # Read the weights to Python floats once, instead of calling .item()
        # repeatedly inside the accumulation.
        weights = norm_w.detach().cpu().tolist()
        w_sum = sum(weights)
        mmin = self.global_config['mmin']
        mmax = self.global_config['mmax']
        estimated = []
        for axis in (0, 1):
            weighted = sum(w * rate[axis]
                           for w, rate in zip(weights, self.dilation_rates))
            estimated.append(
                np.clip(int(round(weighted / w_sum)), mmin, mmax).item())
        estimated = tuple(estimated)

        self.op_layer.dilation = estimated
        self.op_layer.padding = self.get_padding(estimated)
        self.dilation_rates = [estimated]
        if self.verbose:
            print_log(f'Estimate as {estimated}', 'current')

    def expand_rates(self) -> None:
        """Expand dilation rate.

        Rebuild the candidate rates around ``op_layer``'s current dilation
        and reset every branch weight to ``global_config['init_alphas']``.
        """
        dilation_rates = self._expand_dilation_rates(self.op_layer.dilation)
        self.dilation_rates = copy.deepcopy(dilation_rates)
        if self.verbose:
            print_log(f'Expand as {self.dilation_rates}', 'current')
        nn.init.constant_(self.branch_weights,
                          self.global_config['init_alphas'])

    def get_padding(self, dilation: tuple) -> tuple:
        """Return per-axis padding for ``dilation``.

        Each axis uses ``get_single_padding`` with ``op_layer``'s kernel size
        and stride on that axis.

        Args:
            dilation (tuple): The (axis 0, axis 1) dilation rate.

        Returns:
            tuple: The (axis 0, axis 1) padding.
        """
        padding = (get_single_padding(self.op_layer.kernel_size[0],
                                      self.op_layer.stride[0], dilation[0]),
                   get_single_padding(self.op_layer.kernel_size[1],
                                      self.op_layer.stride[1], dilation[1]))
        return padding