Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import inspect | |
import torch | |
from ...utils import Registry, build_from_cfg | |
# Module-level registries.
# OPTIMIZERS holds optimizer classes (torch.optim ones are added by
# register_torch_optimizers below).
# OPTIMIZER_BUILDERS holds optimizer-constructor classes; build_optimizer
# looks one up by its ``type`` name through build_optimizer_constructor.
OPTIMIZERS, OPTIMIZER_BUILDERS = (
    Registry('optimizer'),
    Registry('optimizer builder'),
)
def register_torch_optimizers():
    """Register every ``torch.optim.Optimizer`` subclass into ``OPTIMIZERS``.

    Walks ``dir(torch.optim)`` in order. Dunder attributes are skipped.
    Every attribute that is a class deriving from ``torch.optim.Optimizer``
    is registered. Because ``issubclass(C, C)`` is true, the ``Optimizer``
    base class itself is registered as well.

    Returns:
        list[str]: The ``torch.optim`` attribute names that were registered,
        in the order they were visited.
    """
    registered = []
    for attr_name in dir(torch.optim):
        # Ignore module dunders such as ``__name__`` and ``__file__``.
        if attr_name.startswith('__'):
            continue
        candidate = getattr(torch.optim, attr_name)
        # Only classes can be tested with issubclass; functions and
        # submodules fall through here.
        if not inspect.isclass(candidate):
            continue
        if not issubclass(candidate, torch.optim.Optimizer):
            continue
        OPTIMIZERS.register_module()(candidate)
        registered.append(attr_name)
    return registered


# Executed once at import time so torch's built-in optimizers can be
# referenced by name in configs; keeps the list of what was registered.
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer_constructor(cfg):
    """Instantiate an optimizer constructor described by ``cfg``.

    Args:
        cfg (dict): Config whose ``type`` entry names a class registered in
            ``OPTIMIZER_BUILDERS``. The remaining entries are handled by
            ``build_from_cfg``.

    Returns:
        The object produced by ``build_from_cfg`` for that registry.
    """
    constructor = build_from_cfg(cfg, OPTIMIZER_BUILDERS)
    return constructor
def build_optimizer(model, cfg):
    """Build an optimizer for ``model`` from an optimizer config.

    Two optional keys are removed from a deep copy of ``cfg`` before the
    remaining entries are used as the optimizer config:

    * ``constructor``: registry name of the constructor class to use.
      Defaults to ``'DefaultOptimizerConstructor'``.
    * ``paramwise_cfg``: forwarded to the constructor. Defaults to ``None``.

    Args:
        model: Passed unchanged to the constructor's ``__call__``
            (presumably an ``nn.Module``; confirm against the constructor).
        cfg (dict): Optimizer config. It must support ``pop`` with a
            default. It is not modified, because a deep copy is used.

    Returns:
        Whatever the constructor returns when called with ``model``
        (expected to be a ``torch.optim.Optimizer``; TODO confirm per
        constructor).
    """
    # Deep copy so the pops below never mutate the caller's config.
    remaining = copy.deepcopy(cfg)
    builder_type = remaining.pop('constructor',
                                 'DefaultOptimizerConstructor')
    paramwise = remaining.pop('paramwise_cfg', None)
    constructor_cfg = {
        'type': builder_type,
        'optimizer_cfg': remaining,
        'paramwise_cfg': paramwise,
    }
    return build_optimizer_constructor(constructor_cfg)(model)