# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Dict

import torch.nn as nn
from mmengine.registry import MODELS

# Register the built-in 2D padding layers under short names so configs can
# refer to them as ``dict(type='zero' | 'reflect' | 'replicate', ...)``.
MODELS.register_module('zero', module=nn.ZeroPad2d)
MODELS.register_module('reflect', module=nn.ReflectionPad2d)
MODELS.register_module('replicate', module=nn.ReplicationPad2d)


def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
    """Build padding layer.

    Args:
        cfg (dict): The padding layer config, which should contain:

            - type (str | type): Layer type. Either a name looked up in the
              ``MODELS`` registry (e.g. ``'zero'``, ``'reflect'``,
              ``'replicate'``) or a class, which is instantiated directly.
            - layer args: Args needed to instantiate a padding layer.
        *args: Positional arguments forwarded to the layer constructor.
        **kwargs: Keyword arguments forwarded to the layer constructor,
            together with the remaining (non-``type``) keys of ``cfg``.

    Returns:
        nn.Module: Created padding layer.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no ``"type"`` key, or if the given type
            name cannot be found in the registry.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    # Work on a copy so popping 'type' does not mutate the caller's config.
    cfg_ = cfg.copy()
    padding_type = cfg_.pop('type')

    # A class given directly bypasses the registry lookup.
    if inspect.isclass(padding_type):
        return padding_type(*args, **kwargs, **cfg_)

    # Switch registry to the target scope. If `padding_type` cannot be found
    # in the registry, fallback to search `padding_type` in the
    # mmengine.MODELS.
    with MODELS.switch_scope_and_registry(None) as registry:
        padding_layer = registry.get(padding_type)
    if padding_layer is None:
        # Report the requested name: `padding_layer` is always None here, so
        # interpolating it would only ever print "None".
        raise KeyError(f'Cannot find {padding_type} in registry under scope '
                       f'name {registry.scope}')
    layer = padding_layer(*args, **kwargs, **cfg_)
    return layer