# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import inspect from typing import Dict, Tuple, Union import torch.nn as nn from mmengine.registry import MODELS from mmengine.utils import is_tuple_of from mmengine.utils.dl_utils.parrots_wrapper import (SyncBatchNorm, _BatchNorm, _InstanceNorm) MODELS.register_module('BN', module=nn.BatchNorm2d) MODELS.register_module('BN1d', module=nn.BatchNorm1d) MODELS.register_module('BN2d', module=nn.BatchNorm2d) MODELS.register_module('BN3d', module=nn.BatchNorm3d) MODELS.register_module('SyncBN', module=SyncBatchNorm) MODELS.register_module('GN', module=nn.GroupNorm) MODELS.register_module('LN', module=nn.LayerNorm) MODELS.register_module('IN', module=nn.InstanceNorm2d) MODELS.register_module('IN1d', module=nn.InstanceNorm1d) MODELS.register_module('IN2d', module=nn.InstanceNorm2d) MODELS.register_module('IN3d', module=nn.InstanceNorm3d) def infer_abbr(class_type): """Infer abbreviation from the class name. When we build a norm layer with `build_norm_layer()`, we want to preserve the norm type in variable names, e.g, self.bn1, self.gn. This method will infer the abbreviation to map class types to abbreviations. Rule 1: If the class has the property "_abbr_", return the property. Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and "in" respectively. Rule 3: If the class name contains "batch", "group", "layer" or "instance", the abbreviation of this layer will be "bn", "gn", "ln" and "in" respectively. Rule 4: Otherwise, the abbreviation falls back to "norm". Args: class_type (type): The norm layer type. Returns: str: The inferred abbreviation. """ if not inspect.isclass(class_type): raise TypeError( f'class_type must be a type, but got {type(class_type)}') if hasattr(class_type, '_abbr_'): return class_type._abbr_ if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN return 'in' elif issubclass(class_type, _BatchNorm): return 'bn' elif issubclass(class_type, nn.GroupNorm): return 'gn' elif issubclass(class_type, nn.LayerNorm): return 'ln' else: class_name = class_type.__name__.lower() if 'batch' in class_name: return 'bn' elif 'group' in class_name: return 'gn' elif 'layer' in class_name: return 'ln' elif 'instance' in class_name: return 'in' else: return 'norm_layer' def build_norm_layer(cfg: Dict, num_features: int, postfix: Union[int, str] = '') -> Tuple[str, nn.Module]: """Build normalization layer. Args: cfg (dict): The norm layer config, which should contain: - type (str): Layer type. - layer args: Args needed to instantiate a norm layer. - requires_grad (bool, optional): Whether stop gradient updates. num_features (int): Number of input channels. postfix (int | str): The postfix to be appended into norm abbreviation to create named layer. Returns: tuple[str, nn.Module]: The first element is the layer name consisting of abbreviation and postfix, e.g., bn1, gn. The second element is the created norm layer. """ if not isinstance(cfg, dict): raise TypeError('cfg must be a dict') if 'type' not in cfg: raise KeyError('the cfg dict must contain the key "type"') cfg_ = cfg.copy() layer_type = cfg_.pop('type') if inspect.isclass(layer_type): norm_layer = layer_type else: # Switch registry to the target scope. If `norm_layer` cannot be found # in the registry, fallback to search `norm_layer` in the # mmengine.MODELS. with MODELS.switch_scope_and_registry(None) as registry: norm_layer = registry.get(layer_type) if norm_layer is None: raise KeyError(f'Cannot find {norm_layer} in registry under ' f'scope name {registry.scope}') abbr = infer_abbr(norm_layer) assert isinstance(postfix, (int, str)) name = abbr + str(postfix) requires_grad = cfg_.pop('requires_grad', True) cfg_.setdefault('eps', 1e-5) if norm_layer is not nn.GroupNorm: layer = norm_layer(num_features, **cfg_) if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): layer._specify_ddp_gpu_num(1) else: assert 'num_groups' in cfg_ layer = norm_layer(num_channels=num_features, **cfg_) for param in layer.parameters(): param.requires_grad = requires_grad return name, layer def is_norm(layer: nn.Module, exclude: Union[type, tuple, None] = None) -> bool: """Check if a layer is a normalization layer. Args: layer (nn.Module): The layer to be checked. exclude (type | tuple[type]): Types to be excluded. Returns: bool: Whether the layer is a norm layer. """ if exclude is not None: if not isinstance(exclude, tuple): exclude = (exclude, ) if not is_tuple_of(exclude, type): raise TypeError( f'"exclude" must be either None or type or a tuple of types, ' f'but got {type(exclude)}: {exclude}') if exclude and isinstance(layer, exclude): return False all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm) return isinstance(layer, all_norm_bases)