# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Dict, Optional

import mmengine
import torch  # noqa
import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.registry import HOOKS

from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp  # noqa
from .utils import get_single_padding, write_to_json


@HOOKS.register_module()
class RFSearchHook(Hook):
    """Receptive field search via dilation rates.

    Please refer to `RF-Next: Efficient Receptive Field Search for
    Convolutional Neural Networks`_ for more details.

    Args:
        mode (str, optional): It can be set to the following types:
            'search', 'fixed_single_branch', or 'fixed_multi_branch'.
            Defaults to 'search'.
        config (Dict, optional): config dict of search.
            By default this config contains "search",
            and config["search"] must include:

            - "step": recording the current searching step.
            - "max_step": The maximum number of searching steps
              to update the structures.
            - "search_interval": The interval (epoch/iteration)
              between two updates.
            - "exp_rate": The controller of the sparsity of search space.
            - "init_alphas": The value for initializing weights of each
              branch.
            - "mmin": The minimum dilation rate.
            - "mmax": The maximum dilation rate.
            - "num_branches": The controller of the size of search space
              (the number of branches).
            - "skip_layer": The modules in skip_layer will be ignored
              during the receptive field search.
        rfstructure_file (str, optional): Path to load searched receptive
            fields of the model. Defaults to None.
        by_epoch (bool, optional): Determine to perform step by epoch or
            by iteration. If set to True, it will step by epoch.
            Otherwise, by iteration. Defaults to True.
        verbose (bool): Determines whether to print rf-next related logging
            messages. Defaults to True.
    """
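
    # An illustrative ``config`` value (a sketch only: the keys follow the
    # docstring above, but the numbers and layer names are placeholders
    # rather than recommended settings):
    #
    #     config = dict(
    #         search=dict(
    #             step=0,
    #             max_step=12,
    #             search_interval=1,
    #             exp_rate=0.5,
    #             init_alphas=0.01,
    #             mmin=1,
    #             mmax=24,
    #             num_branches=2,
    #             skip_layer=['stem', 'layer1']))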

    def __init__(self,
                 mode: str = 'search',
                 config: Dict = {},
                 rfstructure_file: Optional[str] = None,
                 by_epoch: bool = True,
                 verbose: bool = True):
        assert mode in ['search', 'fixed_single_branch', 'fixed_multi_branch']
        assert config is not None
        self.config = config
        self.config['structure'] = {}
        self.verbose = verbose
        if rfstructure_file is not None:
            rfstructure = mmengine.load(rfstructure_file)['structure']
            self.config['structure'] = rfstructure
        self.mode = mode
        self.num_branches = self.config['search']['num_branches']
        self.by_epoch = by_epoch

    def init_model(self, model: nn.Module):
        """Init model with search ability.

        Args:
            model (nn.Module): pytorch model

        Raises:
            NotImplementedError: only support three modes:
                search/fixed_single_branch/fixed_multi_branch
        """
        if self.verbose:
            print_log('RFSearch init begin.', 'current')
        if self.mode == 'search':
            if self.config['structure']:
                self.set_model(model, search_op='Conv2d')
            self.wrap_model(model, search_op='Conv2d')
        elif self.mode == 'fixed_single_branch':
            self.set_model(model, search_op='Conv2d')
        elif self.mode == 'fixed_multi_branch':
            self.set_model(model, search_op='Conv2d')
            self.wrap_model(model, search_op='Conv2d')
        else:
            raise NotImplementedError
        if self.verbose:
            print_log('RFSearch init end.', 'current')

    def after_train_epoch(self, runner):
        """Perform a dilation searching step after one training epoch."""
        if self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        """Perform a dilation searching step after one training iteration."""
        if not self.by_epoch and self.mode == 'search':
            self.step(runner.model, runner.work_dir)

    def step(self, model: nn.Module, work_dir: str) -> None:
        """Perform a dilation searching step.

        Args:
            model (nn.Module): pytorch model
            work_dir (str): Directory to save the searching results.
        """
        self.config['search']['step'] += 1
        step = self.config['search']['step']
        if step % self.config['search']['search_interval'] == 0 and \
                step < self.config['search']['max_step']:
            self.estimate_and_expand(model)
            for name, module in model.named_modules():
                if isinstance(module, BaseConvRFSearchOp):
                    self.config['structure'][name] = module.op_layer.dilation

        write_to_json(
            self.config,
            os.path.join(work_dir, 'local_search_config_step%d.json' % step),
        )

    def estimate_and_expand(self, model: nn.Module) -> None:
        """Estimate and search for RFConvOp.

        Args:
            model (nn.Module): pytorch model
        """
        for module in model.modules():
            if isinstance(module, BaseConvRFSearchOp):
                module.estimate_rates()
                module.expand_rates()

    def wrap_model(self,
                   model: nn.Module,
                   search_op: str = 'Conv2d',
                   prefix: str = '') -> None:
        """Wrap model to support searchable conv op.

        Args:
            model (nn.Module): pytorch model
            search_op (str): The module that uses RF search.
                Defaults to 'Conv2d'.
            prefix (str): Prefix for function recursion. Defaults to ''.
        """
        op = 'torch.nn.' + search_op
        for name, module in model.named_children():
            if prefix == '':
                fullname = 'module.' + name
            else:
                fullname = prefix + '.' + name
            if self.config['search']['skip_layer'] is not None:
                if any(layer in fullname
                       for layer in self.config['search']['skip_layer']):
                    continue
            if isinstance(module, eval(op)):
                # Only wrap convs with an odd kernel size larger than 1.
                if (module.kernel_size[0] > 1
                        and module.kernel_size[0] % 2 != 0) or (
                            module.kernel_size[1] > 1
                            and module.kernel_size[1] % 2 != 0):
                    moduleWrap = eval(search_op + 'RFSearchOp')(
                        module, self.config['search'], self.verbose)
                    moduleWrap = moduleWrap.to(module.weight.device)
                    if self.verbose:
                        print_log(
                            'Wrap model %s to %s.' %
                            (str(module), str(moduleWrap)), 'current')
                    setattr(model, name, moduleWrap)
            elif not isinstance(module, BaseConvRFSearchOp):
                self.wrap_model(module, search_op, fullname)

    def set_model(self,
                  model: nn.Module,
                  search_op: str = 'Conv2d',
                  init_rates: Optional[int] = None,
                  prefix: str = '') -> None:
        """Set model based on config.

        Args:
            model (nn.Module): pytorch model
            search_op (str): The module that uses RF search.
                Defaults to 'Conv2d'.
            init_rates (int, optional): Set to other initial dilation rates.
                Defaults to None.
            prefix (str): Prefix for function recursion. Defaults to ''.
        """
        op = 'torch.nn.' + search_op
        for name, module in model.named_children():
            if prefix == '':
                fullname = 'module.' + name
            else:
                fullname = prefix + '.' + name
            if self.config['search']['skip_layer'] is not None:
                if any(layer in fullname
                       for layer in self.config['search']['skip_layer']):
                    continue
            if isinstance(module, eval(op)):
                # Only reset convs with an odd kernel size larger than 1.
                if (module.kernel_size[0] > 1
                        and module.kernel_size[0] % 2 != 0) or (
                            module.kernel_size[1] > 1
                            and module.kernel_size[1] % 2 != 0):
                    if isinstance(self.config['structure'][fullname], int):
                        self.config['structure'][fullname] = [
                            self.config['structure'][fullname],
                            self.config['structure'][fullname]
                        ]
                    module.dilation = (
                        self.config['structure'][fullname][0],
                        self.config['structure'][fullname][1],
                    )
                    module.padding = (
                        get_single_padding(
                            module.kernel_size[0], module.stride[0],
                            self.config['structure'][fullname][0]),
                        get_single_padding(
                            module.kernel_size[1], module.stride[1],
                            self.config['structure'][fullname][1]))
                    setattr(model, name, module)
                    if self.verbose:
                        print_log(
                            'Set module %s dilation as: [%d %d]' %
                            (fullname, module.dilation[0],
                             module.dilation[1]), 'current')
            elif not isinstance(module, BaseConvRFSearchOp):
                self.set_model(module, search_op, init_rates, fullname)
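

# Typical workflow (an illustrative sketch, not a prescribed API): in 'search'
# mode the hook wraps eligible convolutions and, every ``search_interval``
# epochs/iterations, updates their dilation rates, recording the searched
# structure in ``local_search_config_step<N>.json`` under the work directory.
# Such a file can later be passed back as ``rfstructure_file`` together with
# one of the ``fixed_*`` modes to train or test with the searched rates. The
# path and the ``search_cfg``/``model`` names below are placeholders:
#
#     hook = RFSearchHook(
#         mode='fixed_multi_branch',
#         rfstructure_file='work_dir/local_search_config_step10.json',
#         config=dict(search=search_cfg))
#     hook.init_model(model)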