Spaces: Running on Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import inspect | |
from typing import Dict | |
import torch.nn as nn | |
from mmengine.registry import MODELS | |
# Register the 2-D padding layers from ``torch.nn`` in ``MODELS`` under short
# aliases, so that ``build_padding_layer`` can resolve them by name.
for _pad_alias, _pad_cls in (
    ('zero', nn.ZeroPad2d),
    ('reflect', nn.ReflectionPad2d),
    ('replicate', nn.ReplicationPad2d),
):
    MODELS.register_module(_pad_alias, module=_pad_cls)
# Remove the loop variables so they do not linger as module attributes.
del _pad_alias, _pad_cls
def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
    """Build padding layer.

    Args:
        cfg (dict): The padding layer config, which should contain:

            - type (str | type): Layer type. Either a name registered in
              ``MODELS`` (e.g. ``'zero'``, ``'reflect'``, ``'replicate'``)
              or a class, which is then instantiated directly without a
              registry lookup.
            - layer args: Args needed to instantiate a padding layer.
        *args: Positional arguments forwarded to the layer constructor.
        **kwargs: Keyword arguments forwarded to the layer constructor,
            together with the remaining (non-``type``) keys of ``cfg``.

    Returns:
        nn.Module: Created padding layer.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no ``"type"`` key, or if a string ``type``
            cannot be found in the registry.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    # Work on a copy so popping 'type' does not mutate the caller's config.
    cfg_ = cfg.copy()
    padding_type = cfg_.pop('type')

    # A class given directly bypasses the registry entirely.
    if inspect.isclass(padding_type):
        return padding_type(*args, **kwargs, **cfg_)

    # Switch registry to the target scope. If `padding_type` cannot be found
    # in the registry, fallback to search `padding_type` in the
    # mmengine.MODELS.
    with MODELS.switch_scope_and_registry(None) as registry:
        padding_layer = registry.get(padding_type)
    if padding_layer is None:
        # Report the requested type name: `padding_layer` is always None on
        # this path, so interpolating it would produce a useless message.
        raise KeyError(f'Cannot find {padding_type} in registry under scope '
                       f'name {registry.scope}')
    return padding_layer(*args, **kwargs, **cfg_)