# -*- coding: utf-8 -*-
from typing import Callable, Dict

import torch
from torchnlp.encoders.text.text_encoder import TextEncoder

from .tokenizer_base import TextEncoderBase


class XLMRTextEncoder(TextEncoderBase):
    """
    XLM-RoBERTa encoder from Fairseq.

    :param encode_func: XLM-R tokenization function. This can be easily
        obtained from the fairseq model (e.g., the xlmr.encode callable).
    :param vocabulary: Dictionary containing the XLM-R vocabulary. This can
        be easily obtained from the fairseq model
        (e.g., xlmr.task.source_dictionary.indices).
    """

    def __init__(self, encode_func: Callable, vocabulary: Dict[str, int]) -> None:
        super().__init__()
        self.encode_func = encode_func

        # Properties from the base class
        self.stoi = vocabulary
        self.itos = {v: k for k, v in vocabulary.items()}
        self._pad_index = self.stoi["<pad>"]
        self._eos_index = self.stoi["</s>"]
        self._unk_index = self.stoi["<unk>"]
        self._bos_index = self.stoi["<s>"]
        self._mask_index = self.stoi["<mask>"]

    def encode(self, sequence: str) -> torch.Tensor:
        """Encodes a 'sequence'.

        :param sequence: String 'sequence' to encode.

        :return: torch.Tensor with the encoding of the 'sequence'.
        """
        # Run the base class encode first (handles reversibility checks),
        # then delegate the actual tokenization to the fairseq callable.
        sequence = TextEncoder.encode(self, sequence)
        return self.encode_func(sequence)
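

# A minimal usage sketch (not part of the module's API): constructing the
# encoder from a fairseq XLM-R checkpoint loaded via torch.hub. The hub
# identifiers "pytorch/fairseq" and "xlmr.base" are the standard fairseq
# ones, but loading requires the fairseq package and network access, so
# treat this as illustrative rather than a tested entry point.
#
#   import torch
#   xlmr = torch.hub.load("pytorch/fairseq", "xlmr.base")
#   encoder = XLMRTextEncoder(
#       encode_func=xlmr.encode,
#       vocabulary=xlmr.task.source_dictionary.indices,
#   )
#   token_ids = encoder.encode("Hello world!")  # -> torch.Tensor of ids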