TomatoCocotree
上传
6a62ffb
from typing import Optional, Callable, Union
from torch.nn import Module
from tha3.module.module_factory import ModuleFactory
from tha3.nn.init_function import create_init_function
from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha3.nn.normalization import NormalizationLayerFactory
from tha3.nn.spectral_norm import apply_spectral_norm
def wrap_conv_or_linear_module(module: Module,
                               initialization_method: Union[str, Callable[[Module], Module]],
                               use_spectral_norm: bool):
    """Initialize ``module`` and hand the result to ``apply_spectral_norm``.

    Args:
        module: The module to initialize and wrap.
        initialization_method: Either a string, which is turned into an
            initializer via ``create_init_function``, or a callable that is
            used as the initializer directly.
        use_spectral_norm: Flag forwarded unchanged to
            ``apply_spectral_norm`` (presumably enables spectral
            normalization when true -- confirm in ``tha3.nn.spectral_norm``).

    Returns:
        Whatever ``apply_spectral_norm`` returns for the initialized module.
    """
    # Strings name an initializer to look up; anything else is assumed to
    # already be a callable taking and returning a module.
    initializer = (
        create_init_function(initialization_method)
        if isinstance(initialization_method, str)
        else initialization_method
    )
    initialized_module = initializer(module)
    return apply_spectral_norm(initialized_module, use_spectral_norm)
class BlockArgs:
    """Settings shared by network building blocks.

    Stores the initialization method, the spectral-norm flag, an optional
    normalization layer factory, and the nonlinearity factory as resolved by
    ``resolve_nonlinearity_factory``.
    """

    def __init__(self,
                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                 use_spectral_norm: bool = False,
                 normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                 nonlinearity_factory: Optional[ModuleFactory] = None):
        """Record the block configuration.

        Args:
            initialization_method: Name of an initializer (default ``'he'``)
                or a callable mapping a module to a module.
            use_spectral_norm: Stored as-is; passed on by ``wrap_module``.
            normalization_layer_factory: Stored as-is; may be ``None``.
            nonlinearity_factory: Passed through
                ``resolve_nonlinearity_factory`` before being stored
                (presumably substitutes a default when ``None`` -- confirm
                in ``tha3.nn.nonlinearity_factory``).
        """
        # Resolve first so a failure here happens before any attribute is set.
        resolved_nonlinearity = resolve_nonlinearity_factory(nonlinearity_factory)
        self.nonlinearity_factory = resolved_nonlinearity
        self.normalization_layer_factory = normalization_layer_factory
        self.use_spectral_norm = use_spectral_norm
        self.initialization_method = initialization_method

    def wrap_module(self, module: Module) -> Module:
        """Apply this configuration's initializer and spectral-norm flag to ``module``."""
        init_func = self.get_init_func()
        return wrap_conv_or_linear_module(module, init_func, self.use_spectral_norm)

    def get_init_func(self) -> Callable[[Module], Module]:
        """Return the initializer as a callable.

        A callable ``initialization_method`` is returned unchanged; a string
        is converted with ``create_init_function``.
        """
        method = self.initialization_method
        if not isinstance(method, str):
            return method
        return create_init_function(method)