vae_celeba/disvae/utils/initialization.py

import torch
from torch import nn


def get_activation_name(activation):
    """Given a string or a `torch.nn.modules.activation`, return the name of the activation."""
    if isinstance(activation, str):
        return activation

    # `nn.Softmax` has no dedicated gain in `nn.init.calculate_gain`, so it is mapped to the sigmoid gain.
    mapper = {nn.LeakyReLU: "leaky_relu", nn.ReLU: "relu", nn.Tanh: "tanh",
              nn.Sigmoid: "sigmoid", nn.Softmax: "sigmoid"}
    for k, v in mapper.items():
        if isinstance(activation, k):
            return v  # return the name, not the module class

    raise ValueError("Unknown given activation type: {}".format(activation))


def get_gain(activation):
    """Given an object of `torch.nn.modules.activation` or an activation name,
    return the recommended gain for that activation."""
    if activation is None:
        return 1

    activation_name = get_activation_name(activation)

    # `negative_slope` only exists on `nn.LeakyReLU` instances; for the string "leaky_relu"
    # let `calculate_gain` fall back to its default slope of 0.01.
    param = None
    if activation_name == "leaky_relu" and isinstance(activation, nn.LeakyReLU):
        param = activation.negative_slope

    return nn.init.calculate_gain(activation_name, param)
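
# Rough reference values (standard `nn.init.calculate_gain` outputs, listed here for convenience):
# get_gain("relu") ~= 1.414 (sqrt(2)), get_gain("tanh") == 5/3, get_gain("sigmoid") == 1,
# and get_gain(None) == 1.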


def linear_init(layer, activation="relu"):
    """Initialize a linear layer.

    Args:
        layer (nn.Linear): layer whose parameters to initialize.
        activation (`torch.nn.modules.activation` or str, optional): activation that
            will be used on the `layer`.
    """
    x = layer.weight

    if activation is None:
        return nn.init.xavier_uniform_(x)

    activation_name = get_activation_name(activation)

    if activation_name == "leaky_relu":
        a = 0 if isinstance(activation, str) else activation.negative_slope
        return nn.init.kaiming_uniform_(x, a=a, nonlinearity='leaky_relu')
    elif activation_name == "relu":
        return nn.init.kaiming_uniform_(x, nonlinearity='relu')
    elif activation_name in ["sigmoid", "tanh"]:
        return nn.init.xavier_uniform_(x, gain=get_gain(activation))
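
# Minimal, hypothetical usage sketch (layer sizes and slope are made up for illustration):
# fc = nn.Linear(128, 64)
# linear_init(fc, activation=nn.LeakyReLU(0.2))  # Kaiming init using the 0.2 negative slope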


def weights_init(module):
    """Initialize the weights of `module` according to its type (meant for `nn.Module.apply`)."""
    if isinstance(module, torch.nn.modules.conv._ConvNd):
        # TO-DO: check literature
        linear_init(module)
    elif isinstance(module, nn.Linear):
        linear_init(module)
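

# A minimal usage sketch (assumed entry point, not part of the original file): `nn.Module.apply`
# visits every sub-module recursively, so `weights_init` re-initializes all conv and linear layers.
if __name__ == "__main__":
    # Hypothetical toy model, purely for illustration.
    model = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(32 * 15 * 15, 10),
    )
    model.apply(weights_init)
    print("Initialized {} parameters".format(sum(p.numel() for p in model.parameters())))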