rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame
5.87 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Dict, Tuple, Union
import torch.nn as nn
from mmengine.registry import MODELS
from mmengine.utils import is_tuple_of
from mmengine.utils.dl_utils.parrots_wrapper import (SyncBatchNorm, _BatchNorm,
_InstanceNorm)
# Register the built-in normalization layers under the short names used in
# configs. The bare 'BN' and 'IN' names map to the same classes as the
# 'BN2d' and 'IN2d' names respectively.
for _norm_name, _norm_cls in (
        ('BN', nn.BatchNorm2d),
        ('BN1d', nn.BatchNorm1d),
        ('BN2d', nn.BatchNorm2d),
        ('BN3d', nn.BatchNorm3d),
        ('SyncBN', SyncBatchNorm),
        ('GN', nn.GroupNorm),
        ('LN', nn.LayerNorm),
        ('IN', nn.InstanceNorm2d),
        ('IN1d', nn.InstanceNorm1d),
        ('IN2d', nn.InstanceNorm2d),
        ('IN3d', nn.InstanceNorm3d),
):
    MODELS.register_module(_norm_name, module=_norm_cls)
# Do not leave the loop variables behind in the module namespace.
del _norm_name, _norm_cls
def infer_abbr(class_type):
    """Infer the abbreviation of a norm layer from its class.

    When a norm layer is built with `build_norm_layer()`, the norm type is
    preserved in the variable name, e.g., self.bn1, self.gn. This function
    maps a class to the abbreviation used for that purpose.

    Rule 1: If the class has the attribute "_abbr_", return that attribute.
    Rule 2: If the class is a subclass of _InstanceNorm, _BatchNorm,
    GroupNorm or LayerNorm, return "in", "bn", "gn" or "ln" respectively
    (checked in that order).
    Rule 3: If the lower-cased class name contains "batch", "group", "layer"
    or "instance", return "bn", "gn", "ln" or "in" respectively (checked in
    that order).
    Rule 4: Otherwise, return "norm_layer".

    Args:
        class_type (type): The norm layer type.

    Returns:
        str: The inferred abbreviation.

    Raises:
        TypeError: If ``class_type`` is not a class.
    """
    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')

    # Rule 1: an explicit abbreviation on the class wins.
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_

    # Rule 2: match by base class. _InstanceNorm is tested before
    # _BatchNorm because instance norm may be a subclass of batch norm.
    base_abbrs = (
        (_InstanceNorm, 'in'),
        (_BatchNorm, 'bn'),
        (nn.GroupNorm, 'gn'),
        (nn.LayerNorm, 'ln'),
    )
    for base, abbr in base_abbrs:
        if issubclass(class_type, base):
            return abbr

    # Rule 3: match by keyword in the class name; first hit wins.
    lowered_name = class_type.__name__.lower()
    keyword_abbrs = (
        ('batch', 'bn'),
        ('group', 'gn'),
        ('layer', 'ln'),
        ('instance', 'in'),
    )
    for keyword, abbr in keyword_abbrs:
        if keyword in lowered_name:
            return abbr

    # Rule 4: nothing matched.
    return 'norm_layer'
def build_norm_layer(cfg: Dict,
                     num_features: int,
                     postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
    """Build normalization layer.

    Args:
        cfg (dict): The norm layer config, which should contain:

            - type (str | type): Layer type, either a registered name or
              the layer class itself.
            - layer args: Args needed to instantiate a norm layer.
            - requires_grad (bool, optional): Value assigned to
              ``requires_grad`` of every parameter of the created layer.
              Defaults to True.

            ``eps`` defaults to 1e-5 when it is not given. The input dict
            is not modified; a shallow copy is used internally.
        num_features (int): Number of input channels.
        postfix (int | str): The postfix to be appended into norm abbreviation
            to create named layer.

    Returns:
        tuple[str, nn.Module]: The first element is the layer name consisting
        of abbreviation and postfix, e.g., bn1, gn. The second element is the
        created norm layer.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no "type" key, or the type name cannot be
            found in the registry.
        AssertionError: If ``postfix`` is neither int nor str, or if the
            layer is ``nn.GroupNorm`` and ``cfg`` lacks "num_groups".
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')

    if inspect.isclass(layer_type):
        norm_layer = layer_type
    else:
        # Switch registry to the target scope. If `norm_layer` cannot be found
        # in the registry, fallback to search `norm_layer` in the
        # mmengine.MODELS.
        with MODELS.switch_scope_and_registry(None) as registry:
            norm_layer = registry.get(layer_type)
        if norm_layer is None:
            # Report the requested name; `norm_layer` is always None here.
            raise KeyError(f'Cannot find {layer_type} in registry under '
                           f'scope name {registry.scope}')
    abbr = infer_abbr(norm_layer)

    # Kept as asserts so the raised exception type stays AssertionError.
    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    requires_grad = cfg_.pop('requires_grad', True)
    cfg_.setdefault('eps', 1e-5)
    if norm_layer is not nn.GroupNorm:
        layer = norm_layer(num_features, **cfg_)
        # NOTE(review): presumably a compatibility hook for SyncBN
        # implementations that expose `_specify_ddp_gpu_num` -- confirm.
        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
            layer._specify_ddp_gpu_num(1)
    else:
        # GroupNorm takes the channel count as `num_channels` and also
        # needs `num_groups` from the config.
        assert 'num_groups' in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)

    for param in layer.parameters():
        param.requires_grad = requires_grad

    return name, layer
def is_norm(layer: nn.Module,
            exclude: Union[type, tuple, None] = None) -> bool:
    """Tell whether a layer is a normalization layer.

    A layer counts as a norm layer when it is an instance of _BatchNorm,
    _InstanceNorm, nn.GroupNorm or nn.LayerNorm and is not an instance of
    any excluded type.

    Args:
        layer (nn.Module): The layer to be checked.
        exclude (type | tuple[type] | None): Types to be excluded. A single
            type is treated as a one-element tuple. Defaults to None.

    Returns:
        bool: Whether the layer is a norm layer.

    Raises:
        TypeError: If ``exclude`` is not None, a type, or a tuple of types.
    """
    if exclude is not None:
        # Normalize a single type into a tuple before validating.
        excluded_types = exclude if isinstance(exclude, tuple) else (exclude, )
        if not is_tuple_of(excluded_types, type):
            raise TypeError(
                f'"exclude" must be either None or type or a tuple of types, '
                f'but got {type(excluded_types)}: {excluded_types}')
        # An empty tuple excludes nothing.
        if excluded_types and isinstance(layer, excluded_types):
            return False

    return isinstance(
        layer, (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm))