# -*- coding: utf-8 -*-
r"""
XLM-R Encoder Model
===================
    Pretrained XLM-RoBERTa from the Fairseq framework.
    https://github.com/pytorch/fairseq/tree/master/examples/xlmr
"""
import os
from argparse import Namespace
from typing import Dict

import torch

from polos.models.encoders.encoder_base import Encoder
from polos.tokenizers_ import XLMRTextEncoder
from fairseq.models.roberta import XLMRModel
from torchnlp.download import download_file_maybe_extract
from torchnlp.utils import lengths_to_mask

XLMR_LARGE_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz"
XLMR_LARGE_MODEL_NAME = "xlmr.large/model.pt"

XLMR_BASE_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz"
XLMR_BASE_MODEL_NAME = "xlmr.base/model.pt"

XLMR_LARGE_V0_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz"
XLMR_LARGE_V0_MODEL_NAME = "xlmr.large.v0/model.pt"

XLMR_BASE_V0_URL = "https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz"
XLMR_BASE_V0_MODEL_NAME = "xlmr.base.v0/model.pt"

saving_directory = "./.cache/"


class XLMREncoder(Encoder):
    """
    XLM-RoBERTa encoder from Fairseq.

    :param xlmr: XLM-R model to be used.
    :param tokenizer: XLM-R model tokenizer to be used.
    :param hparams: Namespace with the `pretrained_model` hyper-parameter
        (e.g. "xlmr.base" or "xlmr.large").
    """

    def __init__(
        self, xlmr: XLMRModel, tokenizer: XLMRTextEncoder, hparams: Namespace
    ) -> None:
        super().__init__(tokenizer)
        self._output_units = 768 if "base" in hparams.pretrained_model else 1024
        # Layer count includes the embedding layer plus all transformer layers
        # (1 + 12 for the base model, 1 + 24 for the large model).
        self._n_layers = 13 if "base" in hparams.pretrained_model else 25
        self._max_pos = 512
        # Save some memory by removing the LM and classification heads
        # xlmr.model.decoder.lm_head.dense = None
        # xlmr.model.decoder.classification_heads = None
        self.model = xlmr

    def freeze_embeddings(self) -> None:
        """Freezes the token and positional embedding layers of the network
        to save some memory while training."""
        sentence_encoder = self.model.model.decoder.sentence_encoder
        for module in (
            sentence_encoder.embed_tokens,
            sentence_encoder.embed_positions,
            sentence_encoder.emb_layer_norm,
        ):
            for param in module.parameters():
                param.requires_grad = False

    def layerwise_lr(self, lr: float, decay: float):
        """
        Builds parameter groups with a layer-wise decaying learning rate:
        the LM head uses `lr`, each transformer layer below it is scaled by an
        additional factor of `decay`, and the embeddings get the smallest rate
        (`lr * decay ** num_layers`).

        :param lr: Learning rate for the topmost parameters (LM head).
        :param decay: Multiplicative decay factor applied per layer.

        :return: List with grouped model parameters with layer-wise decaying
            learning rate, suitable for a torch optimizer.
        """
        sentence_encoder = self.model.model.decoder.sentence_encoder

        # Embedding layer: smallest learning rate.
        opt_parameters = [
            {
                "params": sentence_encoder.embed_tokens.parameters(),
                "lr": lr * decay ** self.num_layers,
            },
            {
                "params": sentence_encoder.embed_positions.parameters(),
                "lr": lr * decay ** self.num_layers,
            },
            {
                "params": sentence_encoder.emb_layer_norm.parameters(),
                "lr": lr * decay ** self.num_layers,
            },
        ]
        # Layer-wise parameters: layers closer to the output (larger l) get
        # larger learning rates.
        opt_parameters += [
            {
                "params": sentence_encoder.layers[l].parameters(),
                "lr": lr * decay ** (self.num_layers - 1 - l),
            }
            for l in range(self.num_layers - 1)
        ]
        # Language Model Head parameters keep the base learning rate.
        opt_parameters += [
            {
                "params": self.model.model.decoder.lm_head.layer_norm.parameters(),
                "lr": lr,
            },
            {"params": self.model.model.decoder.lm_head.dense.parameters(), "lr": lr},
        ]
        return opt_parameters

    @property
    def lm_head(self):
        """Language modeling head."""
        return self.model.model.decoder.lm_head

    @classmethod
    def from_pretrained(cls, hparams: Namespace):
        """Downloads the requested Fairseq checkpoint (if needed) and builds the encoder."""
        if not os.path.exists(saving_directory):
            os.makedirs(saving_directory)

        pretrained_model = hparams.pretrained_model
        if pretrained_model == "xlmr.base":
            download_file_maybe_extract(
                XLMR_BASE_URL,
                directory=saving_directory,
                check_files=[XLMR_BASE_MODEL_NAME],
            )
        elif pretrained_model == "xlmr.large":
            download_file_maybe_extract(
                XLMR_LARGE_URL,
                directory=saving_directory,
                check_files=[XLMR_LARGE_MODEL_NAME],
            )
        elif pretrained_model == "xlmr.base.v0":
            download_file_maybe_extract(
                XLMR_BASE_V0_URL,
                directory=saving_directory,
                check_files=[XLMR_BASE_V0_MODEL_NAME],
            )
        elif pretrained_model == "xlmr.large.v0":
            download_file_maybe_extract(
                XLMR_LARGE_V0_URL,
                directory=saving_directory,
                check_files=[XLMR_LARGE_V0_MODEL_NAME],
            )
        else:
            raise Exception(f"{pretrained_model} is an invalid XLM-R model.")

        xlmr = XLMRModel.from_pretrained(
            saving_directory + pretrained_model, checkpoint_file="model.pt"
        )
        # xlmr.eval()
        tokenizer = XLMRTextEncoder(
            xlmr.encode, xlmr.task.source_dictionary.__dict__["indices"]
        )
        return XLMREncoder(xlmr=xlmr, tokenizer=tokenizer, hparams=hparams)

    def forward(
        self, tokens: torch.Tensor, lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Encodes a batch of sequences.

        :param tokens: Torch tensor with the input sequences [batch_size x seq_len].
        :param lengths: Torch tensor with the length of each sequence [batch_size].

        :return: Dictionary with `sentemb` (tensor with dims [batch_size x output_units]),
            `wordemb` (tensor with dims [batch_size x seq_len x output_units]),
            `mask` (input mask), `all_layers` (list with word embeddings from all layers),
            `extra` (all XLM-R layer outputs, same as `all_layers`).
        """
        mask = lengths_to_mask(lengths, device=tokens.device)
        all_layers = self.model.extract_features(tokens, return_all_hiddens=True)
        return {
            # Sentence embedding: representation of the <s> (BOS) token in the last layer.
            "sentemb": all_layers[-1][:, 0, :],
            "wordemb": all_layers[-1],
            "all_layers": all_layers,
            "mask": mask,
            "extra": all_layers,
        }