# -*- coding: utf-8 -*-
r"""
XLM-R Encoder Model
====================
Pretrained XLM-RoBERTa from Fairseq framework.
https://github.com/pytorch/fairseq/tree/master/examples/xlmr
"""
import os
from argparse import Namespace
from typing import Dict
import torch
from polos.models.encoders.encoder_base import Encoder
from polos.tokenizers_ import XLMRTextEncoder
from fairseq.models.roberta import XLMRModel
from torchnlp.download import download_file_maybe_extract
from torchnlp.utils import lengths_to_mask
XLMR_LARGE_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz"
XLMR_LARGE_MODEL_NAME = "xlmr.large/model.pt"
XLMR_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz"
XLMR_BASE_MODEL_NAME = "xlmr.base/model.pt"
XLMR_LARGE_V0_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz"
XLMR_LARGE_V0_MODEL_NAME = "xlmr.large.v0/model.pt"
XLMR_BASE_V0_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz"
XLMR_BASE_V0_MODEL_NAME = "xlmr.base.v0/model.pt"
saving_directory = "./.cache/"
class XLMREncoder(Encoder):
"""
XLM-RoBERTa encoder from Fairseq.
:param xlmr: XLM-R model to be used.
:param tokenizer: XLM-R model tokenizer to be used.
    :param hparams: Namespace with the model hyperparameters (expects a ``pretrained_model`` entry).
"""
def __init__(
self, xlmr: XLMRModel, tokenizer: XLMRTextEncoder, hparams: Namespace
) -> None:
super().__init__(tokenizer)
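        # XLM-R base: 12 transformer layers with 768-dim hidden states;
        # XLM-R large: 24 layers with 1024-dim hidden states.
        # The layer counts below also include the embedding output (12 + 1 / 24 + 1).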
self._output_units = 768 if "base" in hparams.pretrained_model else 1024
self._n_layers = 13 if "base" in hparams.pretrained_model else 25
self._max_pos = 512
        # The LM and classification heads could be dropped here to save some memory,
        # but the LM head is kept because `layerwise_lr` and the `lm_head` property
        # still rely on it.
        # xlmr.model.decoder.lm_head.dense = None
        # xlmr.model.decoder.classification_heads = None
self.model = xlmr
def freeze_embeddings(self) -> None:
""" Freezes the embedding layer of the network to save some memory while training. """
        for module in (
            self.model.model.decoder.sentence_encoder.embed_tokens,
            self.model.model.decoder.sentence_encoder.embed_positions,
            self.model.model.decoder.sentence_encoder.emb_layer_norm,
        ):
            for param in module.parameters():
                param.requires_grad = False
def layerwise_lr(self, lr: float, decay: float):
"""
:return: List with grouped model parameters with layer-wise decaying learning rate
"""
# Embedding layer
opt_parameters = [
{
"params": self.model.model.decoder.sentence_encoder.embed_tokens.parameters(),
"lr": lr * decay ** (self.num_layers),
},
{
"params": self.model.model.decoder.sentence_encoder.embed_positions.parameters(),
"lr": lr * decay ** (self.num_layers),
},
{
"params": self.model.model.decoder.sentence_encoder.emb_layer_norm.parameters(),
"lr": lr * decay ** (self.num_layers),
},
]
# Layer wise parameters
opt_parameters += [
{
"params": self.model.model.decoder.sentence_encoder.layers[
l
].parameters(),
"lr": lr * decay ** (self.num_layers - 1 - l),
}
for l in range(self.num_layers - 1)
]
# Language Model Head parameters
opt_parameters += [
{
"params": self.model.model.decoder.lm_head.layer_norm.parameters(),
"lr": lr,
},
{"params": self.model.model.decoder.lm_head.dense.parameters(), "lr": lr},
]
return opt_parameters
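    # NOTE: the parameter groups returned by `layerwise_lr` can be passed directly
    # to any `torch.optim` optimizer; see the usage sketch at the bottom of this file.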
@property
def lm_head(self):
""" Language modeling head. """
return self.model.model.decoder.lm_head
@classmethod
def from_pretrained(cls, hparams: Namespace):
if not os.path.exists(saving_directory):
os.makedirs(saving_directory)
pretrained_model = hparams.pretrained_model
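        # `download_file_maybe_extract` checks `check_files` first, so the archive is
        # only downloaded (and extracted) when the checkpoint is not already cached.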
if pretrained_model == "xlmr.base":
download_file_maybe_extract(
XLMR_BASE_URL,
directory=saving_directory,
check_files=[XLMR_BASE_MODEL_NAME],
)
elif pretrained_model == "xlmr.large":
download_file_maybe_extract(
XLMR_LARGE_URL,
directory=saving_directory,
check_files=[XLMR_LARGE_MODEL_NAME],
)
elif pretrained_model == "xlmr.base.v0":
download_file_maybe_extract(
XLMR_BASE_V0_URL,
directory=saving_directory,
check_files=[XLMR_BASE_V0_MODEL_NAME],
)
elif pretrained_model == "xlmr.large.v0":
download_file_maybe_extract(
XLMR_LARGE_V0_URL,
directory=saving_directory,
check_files=[XLMR_LARGE_V0_MODEL_NAME],
)
else:
            raise ValueError(f"{pretrained_model} is an invalid XLM-R model.")
xlmr = XLMRModel.from_pretrained(
saving_directory + pretrained_model, checkpoint_file="model.pt"
)
# xlmr.eval()
tokenizer = XLMRTextEncoder(
xlmr.encode, xlmr.task.source_dictionary.__dict__["indices"]
)
return XLMREncoder(xlmr=xlmr, tokenizer=tokenizer, hparams=hparams)
def forward(
self, tokens: torch.Tensor, lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
"""
Encodes a batch of sequences.
        :param tokens: Torch tensor with the input token ids [batch_size x seq_len].
        :param lengths: Torch tensor with the length of each sequence [batch_size].

        :return: Dictionary with `sentemb` (tensor with dims [batch_size x output_units]),
            `wordemb` (tensor with dims [batch_size x seq_len x output_units]),
            `mask` (the input padding mask), `all_layers` (list with the word embeddings
            from every layer) and `extra` (the same list of layer outputs).
"""
mask = lengths_to_mask(lengths, device=tokens.device)
all_layers = self.model.extract_features(tokens, return_all_hiddens=True)
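        # With `return_all_hiddens=True` the hub interface returns one tensor per layer
        # (embedding output + every transformer layer), each [batch_size x seq_len x output_units].
        # The sentence embedding is the last layer's representation of the first token (<s>).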
return {
"sentemb": all_layers[-1][:, 0, :],
"wordemb": all_layers[-1],
"all_layers": all_layers,
"mask": mask,
"extra": (all_layers),
}
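

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API). It assumes network access
# to download the Fairseq checkpoint and enough memory for XLM-R base; the
# sentence, learning rate, and decay below are illustrative values only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    hparams = Namespace(pretrained_model="xlmr.base")
    encoder = XLMREncoder.from_pretrained(hparams)
    encoder.freeze_embeddings()

    # Tokenize one sentence with the underlying Fairseq hub interface and add a
    # batch dimension; `lengths` stores the true length of each sequence.
    tokens = encoder.model.encode("Hello world!").unsqueeze(0)
    lengths = torch.tensor([tokens.size(1)])

    with torch.no_grad():
        out = encoder.forward(tokens, lengths)
    print(out["sentemb"].shape)  # [1, output_units]
    print(out["wordemb"].shape)  # [1, seq_len, output_units]

    # Layer-wise decaying learning rates: the grouped parameters can be handed
    # straight to a torch.optim optimizer.
    optimizer = torch.optim.Adam(encoder.layerwise_lr(lr=1e-5, decay=0.95))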