wav2vec2 / src /utils /functional.py
hoang1007
init
5381499
raw
history blame contribute delete
No virus
1.03 kB
import torch
def init_module_weights(module, *, linear_std=0.5):
    """Initialize the parameters of a single module in place.

    Takes one module and mutates it, so it can be passed to
    ``torch.nn.Module.apply`` (wrap it in ``functools.partial`` to override
    ``linear_std``). Initialization depends on the module type:

    * ``QuantizationModule``: ``weight_proj.weight`` ~ N(0, 1),
      ``weight_proj.bias`` zeroed, ``codebooks`` ~ U(0, 1).
    * ``torch.nn.Linear``: weight ~ N(0, ``linear_std``); bias zeroed if present.
    * ``torch.nn.LayerNorm`` / ``torch.nn.GroupNorm``: weight filled with 1 and
      bias with 0, each only when the module actually has that parameter.
    * ``torch.nn.Conv1d``: Kaiming-normal weight (PyTorch defaults); bias
      zeroed if present.

    Any other module type is left untouched.

    Args:
        module: The module whose own (non-recursive) parameters are initialized.
        linear_std: Standard deviation of the normal distribution used for
            ``torch.nn.Linear`` weights. Defaults to 0.5, the value this
            function has always used.
    """
    # Deferred import -- presumably to avoid a circular import between
    # src.utils and src.model; TODO confirm.
    from src.model.modules import QuantizationModule

    # gumbel softmax requires special init
    if isinstance(module, QuantizationModule):
        torch.nn.init.normal_(module.weight_proj.weight, mean=0.0, std=1.0)
        torch.nn.init.zeros_(module.weight_proj.bias)
        torch.nn.init.uniform_(module.codebooks)
    elif isinstance(module, torch.nn.Linear):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        torch.nn.init.normal_(module.weight, mean=0.0, std=linear_std)
    elif isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
        # Affine parameters are None for LayerNorm(elementwise_affine=False),
        # LayerNorm(bias=False) and GroupNorm(affine=False); touching `.data`
        # on them would raise AttributeError.
        if module.weight is not None:
            torch.nn.init.ones_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, torch.nn.Conv1d):
        torch.nn.init.kaiming_normal_(module.weight)

    # Linear and Conv1d may be built with bias=False, in which case bias is None.
    if (
        isinstance(module, (torch.nn.Linear, torch.nn.Conv1d))
        and module.bias is not None
    ):
        torch.nn.init.zeros_(module.bias)