from typing import Callable

import torch
from torch.nn import Module
from torch.nn.init import kaiming_normal_, xavier_normal_, normal_, zeros_


def create_init_function(method: str = 'none') -> Callable[[Module], Module]:
    """Returns a function that initializes a module's weight tensor in place with the named method."""

    def init(module: Module):
        if method == 'none':
            return module
        elif method == 'he':
            kaiming_normal_(module.weight)
            return module
        elif method == 'xavier':
            xavier_normal_(module.weight)
            return module
        elif method == 'dcgan':
            # DCGAN-style initialization: zero-mean Gaussian with standard deviation 0.02.
            normal_(module.weight, 0.0, 0.02)
            return module
        elif method == 'dcgan_001':
            normal_(module.weight, 0.0, 0.01)
            return module
        elif method == "zero":
            with torch.no_grad():
                zeros_(module.weight)
            return module
        else:
            raise ValueError("Invalid initialization method: %s" % method)

    return init
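

# Illustrative usage (not part of the original module): the returned callable
# initializes a module's `weight` tensor in place and returns the module, e.g.
#
#     init = create_init_function('he')
#     conv = init(Conv2d(in_channels=3, out_channels=16, kernel_size=3))
#
# Modules without a `weight` attribute (e.g. ReLU) would raise AttributeError.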


class HeInitialization:
    """Initializes a module's weight in place with He (Kaiming) normal initialization."""

    def __init__(self, a: float = 0.0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
        self.nonlinearity = nonlinearity
        self.mode = mode
        self.a = a

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)
        return module


class NormalInitialization:
    """Initializes a module's weight in place from a normal distribution."""

    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.std = std
        self.mean = mean

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            normal_(module.weight, self.mean, self.std)
        return module


class XavierInitialization:
    """Initializes a module's weight in place with Xavier (Glorot) normal initialization."""

    def __init__(self, gain: float = 1.0):
        self.gain = gain

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            xavier_normal_(module.weight, self.gain)
        return module


class ZeroInitialization:
    """Initializes a module's weight in place to all zeros."""

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            zeros_(module.weight)
        return module


class NoInitialization:
    def __call__(self, module: Module) -> Module:
        return module
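

if __name__ == "__main__":
    # Small usage sketch (illustrative, not part of the original module): apply
    # the factory and the initializer classes to freshly constructed layers and
    # print the resulting weight statistics.  Layer shapes below are arbitrary.
    from torch.nn import Conv2d, Linear

    init = create_init_function('dcgan')
    conv = init(Conv2d(3, 16, kernel_size=3))
    print("dcgan conv weight std:", conv.weight.std().item())

    linear = HeInitialization()(Linear(128, 64))
    print("he linear weight std:", linear.weight.std().item())

    zeroed = ZeroInitialization()(Linear(8, 8))
    print("zero linear weight sum:", zeroed.weight.sum().item())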