|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.init as init |
|
|
|
|
|
def weight_init(m): |
|
''' |
|
Usage: |
|
model = Model() |
|
model.apply(weight_init) |
|
''' |
|
if isinstance(m, nn.Conv1d): |
|
init.normal_(m.weight.data) |
|
if m.bias is not None: |
|
init.normal_(m.bias.data) |
|
elif isinstance(m, nn.Conv2d): |
|
init.xavier_normal_(m.weight.data) |
|
if m.bias is not None: |
|
init.normal_(m.bias.data) |
|
elif isinstance(m, nn.Conv3d): |
|
init.xavier_normal_(m.weight.data) |
|
if m.bias is not None: |
|
init.normal_(m.bias.data) |
|
elif isinstance(m, nn.ConvTranspose1d): |
|
init.normal_(m.weight.data) |
|
if m.bias is not None: |
|
init.normal_(m.bias.data) |
|
elif isinstance(m, nn.ConvTranspose2d): |
|
init.xavier_normal_(m.weight.data) |
|
if m.bias is not None: |
|
init.normal_(m.bias.data) |
|
elif isinstance(m, nn.ConvTranspose3d): |
|
init.xavier_normal_(m.weight.data) |
|
if m.bias is not None: |
|
init.normal_(m.bias.data) |
|
elif isinstance(m, nn.BatchNorm1d): |
|
init.normal_(m.weight.data, mean=1, std=0.02) |
|
init.constant_(m.bias.data, 0) |
|
elif isinstance(m, nn.BatchNorm2d): |
|
init.normal_(m.weight.data, mean=1, std=0.02) |
|
init.constant_(m.bias.data, 0) |
|
elif isinstance(m, nn.BatchNorm3d): |
|
init.normal_(m.weight.data, mean=1, std=0.02) |
|
init.constant_(m.bias.data, 0) |
|
elif isinstance(m, nn.Linear): |
|
init.xavier_normal_(m.weight.data) |
|
init.normal_(m.bias.data) |
|
elif isinstance(m, nn.LSTM): |
|
for param in m.parameters(): |
|
if len(param.shape) >= 2: |
|
init.orthogonal_(param.data) |
|
else: |
|
init.normal_(param.data) |
|
elif isinstance(m, nn.LSTMCell): |
|
for param in m.parameters(): |
|
if len(param.shape) >= 2: |
|
init.orthogonal_(param.data) |
|
else: |
|
init.normal_(param.data) |
|
elif isinstance(m, nn.GRU): |
|
for param in m.parameters(): |
|
if len(param.shape) >= 2: |
|
init.orthogonal_(param.data) |
|
else: |
|
init.normal_(param.data) |
|
elif isinstance(m, nn.GRUCell): |
|
for param in m.parameters(): |
|
if len(param.shape) >= 2: |
|
init.orthogonal_(param.data) |
|
else: |
|
init.normal_(param.data) |
|
|
|
|
|
if __name__ == '__main__': |
|
pass |