File size: 5,417 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
# Copyright 2021 Xuankai Chang
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder definition."""
import contextlib
import copy
from filelock import FileLock
import logging
import os
from typing import Optional
from typing import Tuple
import torch
from typeguard import check_argument_types
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet2.asr.encoder.abs_encoder import AbsEncoder
class FairSeqWav2Vec2Encoder(AbsEncoder):
    """FairSeq Wav2Vec2 encoder module.

    Downloads a pretrained FairSeq Wav2Vec2.0 checkpoint, loads it, and uses
    it as the encoder. A linear projection is appended when the checkpoint's
    embedding dimension differs from ``output_size``.

    Args:
        input_size: input dim (accepted for interface compatibility; it is
            not referenced by this module)
        w2v_url: url to Wav2Vec2.0 pretrained model (must be non-empty)
        w2v_dir_path: directory to download the Wav2Vec2.0 pretrained model.
        output_size: dimension of the encoder output
        normalize_before: whether to apply layer_norm to the encoder output
        freeze_finetune_updates: number of initial forward calls during which
            the Wav2Vec2.0 model is run under ``torch.no_grad()``.
            0 means the Wav2Vec2.0 model is fine-tuned from the first call.

    Raises:
        ValueError: if ``w2v_url`` is an empty string.
    """

    def __init__(
        self,
        input_size: int,
        w2v_url: str,
        w2v_dir_path: str = "./",
        output_size: int = 256,
        normalize_before: bool = False,
        freeze_finetune_updates: int = 0,
    ):
        assert check_argument_types()
        super().__init__()

        # A checkpoint is always loaded below, so an empty URL can never
        # work; fail early with a clear message instead of crashing later on
        # an unbound ``fairseq`` name.
        if w2v_url == "":
            raise ValueError(
                "w2v_url must be a non-empty URL of a Wav2Vec2.0 pretrained model."
            )

        # FairSeq is an optional dependency, so it is imported lazily here.
        try:
            import fairseq
            from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model
        except Exception:
            print("Error: FairSeq is not properly installed.")
            print(
                "Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done"
            )
            raise

        self.w2v_model_path = download_w2v(w2v_url, w2v_dir_path)

        self._output_size = output_size

        # Override the checkpoint's "data" setting with the local download
        # directory.
        models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [self.w2v_model_path],
            arg_overrides={"data": w2v_dir_path},
        )
        model = models[0]

        # Checkpoints that are not a bare Wav2Vec2Model are expected to wrap
        # one as ``w2v_encoder.w2v_model``; unwrap it.
        if not isinstance(model, Wav2Vec2Model):
            try:
                model = model.w2v_encoder.w2v_model
            except Exception:
                print(
                    "Error: pretrained models should be within: "
                    "'Wav2Vec2Model, Wav2VecCTC' classes, etc."
                )
                raise

        self.encoders = model

        # Snapshot of the pretrained weights, restored by
        # reload_pretrained_parameters().
        self.pretrained_params = copy.deepcopy(model.state_dict())

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # Project to output_size only when the dimensions differ.
        if model.cfg.encoder_embed_dim != output_size:
            # TODO(xkc09): try LSTM
            self.output_layer = torch.nn.Sequential(
                torch.nn.Linear(model.cfg.encoder_embed_dim, output_size),
            )
        else:
            self.output_layer = None

        self.freeze_finetune_updates = freeze_finetune_updates
        # Forward-call counter, registered as a buffer of the module.
        self.register_buffer("num_updates", torch.LongTensor([0]))

    def output_size(self) -> int:
        """Return the output dimension given at construction."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Forward FairSeqWav2Vec2 Encoder.

        Args:
            xs_pad: padded input tensor, passed unchanged to the Wav2Vec2.0
                model (originally documented as (B, L, D) -- TODO confirm)
            ilens: input length (B)
            prev_states: Not to be used now.

        Returns:
            Tuple of (encoder output, output lengths, None). The output
            lengths are the number of non-padded frames per utterance
            according to the padding mask returned by the Wav2Vec2.0 model.
        """
        masks = make_pad_mask(ilens).to(xs_pad.device)

        # ``ft`` is evaluated BEFORE the counter is advanced, so exactly the
        # first ``freeze_finetune_updates`` calls run without gradients.
        # The counter advances on every forward call (the training flag is
        # not consulted) and stops once fine-tuning has been announced.
        ft = self.freeze_finetune_updates <= self.num_updates
        if self.num_updates <= self.freeze_finetune_updates:
            self.num_updates += 1
        elif ft and self.num_updates == self.freeze_finetune_updates + 1:
            self.num_updates += 1
            logging.info("Start fine-tuning wav2vec parameters!")

        with torch.no_grad() if not ft else contextlib.nullcontext():
            enc_outputs = self.encoders(
                xs_pad,
                masks,
                features_only=True,
            )

        xs_pad = enc_outputs["x"]  # (B,T,C),
        masks = enc_outputs["padding_mask"]  # (B, T)

        # True entries of the padding mask mark padded frames.
        olens = (~masks).sum(dim=1)

        if self.output_layer is not None:
            xs_pad = self.output_layer(xs_pad)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        return xs_pad, olens, None

    def reload_pretrained_parameters(self):
        """Restore the Wav2Vec2.0 weights captured at construction time."""
        self.encoders.load_state_dict(self.pretrained_params)
        logging.info("Pretrained Wav2Vec model parameters reloaded!")
def download_w2v(model_url, dir_path):
    """Download a Wav2Vec2.0 checkpoint and the letter dictionary if missing.

    Args:
        model_url: URL of the pretrained model; the text after the last "/"
            is used as the local file name.
        dir_path: directory in which the model and ``dict.ltr.txt`` are
            stored. It is created when it does not exist.

    Returns:
        Local path of the model file inside ``dir_path``.
    """
    os.makedirs(dir_path, exist_ok=True)

    model_name = model_url.split("/")[-1]
    model_path = os.path.join(dir_path, model_name)

    dict_url = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt"
    dict_path = os.path.join(dir_path, dict_url.split("/")[-1])

    # Hold a lock file next to the model while checking and downloading
    # (presumably so parallel jobs do not fetch the same file at once).
    with FileLock(model_path + ".lock"):
        if not os.path.exists(model_path):
            torch.hub.download_url_to_file(model_url, model_path)
            torch.hub.download_url_to_file(dict_url, dict_path)
            logging.info(f"Wav2Vec model downloaded {model_path}")
        else:
            logging.info(f"Wav2Vec model {model_path} already exists.")
            if not os.path.exists(dict_path):
                # The model can be present without the dictionary (e.g. an
                # earlier run stopped between the two downloads, or the
                # model was copied in manually); fetch the dictionary so the
                # directory is complete.
                torch.hub.download_url_to_file(dict_url, dict_path)

    return model_path
|