"""Tokenizer for fairseq RoBERTa checkpoints trained with a SentencePiece vocabulary.

Builds on XLMRobertaTokenizer, but remaps SentencePiece piece ids to the
fairseq dictionary ids read from ``dict.txt`` so that the emitted ids match
the ones the fairseq checkpoint was trained with.
"""

from transformers.models.xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer

SPIECE_UNDERLINE = "▁"  # SentencePiece word-boundary marker

VOCAB_FILES_NAMES = {"spm_model": "spm.model", "custom_vocab_file": "dict.txt"} |
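# Note: "spm.model" is the SentencePiece model itself; "dict.txt" is the fairseq
# dictionary that pins the token-id ordering the checkpoint was trained with.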

PRETRAINED_VOCAB_FILES_MAP = {
    "spm_model": {
        "fairseq-roberta-spm-normal": "fairseq-roberta-all-model/spm.model",
    },
    "custom_vocab_file": {
        "fairseq-roberta-spm-normal": "fairseq-roberta-all-model/dict.txt",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "fairseq-roberta-spm-normal": 512,
}
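
# A fairseq dict.txt holds one "<symbol> <count>" entry per line, optionally
# suffixed with the "#fairseq:overwrite" flag (see _add_from_file below). For
# SentencePiece-based checkpoints the symbols are SentencePiece piece ids; a
# line's position, offset by the four specials registered first, becomes its
# fairseq id. Illustrative lines (values are made up for this example):
#
#     13 960197853
#     6 421093969
#     madeupword0000 0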


class FairSeqRobertaSentencePieceTokenizer(XLMRobertaTokenizer):
    """XLM-R style SentencePiece tokenizer that emits fairseq dictionary ids.

    Ids follow the fairseq ``Dictionary`` layout: bos, pad, eos and unk come
    first, then the entries of ``dict.txt`` in file order, and the mask token
    is appended last.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        spm_model,
        custom_vocab_file,
        bos_token="[CLS]",
        eos_token="[SEP]",
        sep_token="[SEP]",
        cls_token="[CLS]",
        unk_token="[UNK]",
        pad_token="[PAD]",
        mask_token="[MASK]",
        **kwargs
    ):
        super().__init__(
            vocab_file=spm_model,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )

        # Rebuild the fairseq Dictionary order: bos/pad/eos/unk take ids 0-3,
        # dict.txt entries follow, and the mask token is appended last.
        self.symbols = []
        self.count = []
        self.spm_id_to_fairseq_id = {}
        self._add_symbol(self.sp_model.PieceToId(bos_token))
        self._add_symbol(self.sp_model.PieceToId(pad_token))
        self._add_symbol(self.sp_model.PieceToId(eos_token))
        self._add_symbol(self.sp_model.PieceToId(unk_token))
        self._add_from_file(custom_vocab_file)
        self._add_symbol(self.sp_model.PieceToId(mask_token))

        # Piece -> fairseq-id map and its inverse, used by the
        # _convert_token_to_id / _convert_id_to_token overrides below.
        self.fairseq_tokens_to_ids = self._build_fairseq_tokens_to_ids()
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

        # OOV bookkeeping, exposed via get_stats()/reset_stats().
        self._num_tokens_converted = 0
        self._num_tokens_oov = 0

    @property
    def vocab_size(self):
        return len(self.symbols)

    @property
    def pad_token_id(self):
        # Look up the fairseq id directly so the OOV counters stay untouched.
        return self.fairseq_tokens_to_ids.get(self.pad_token)

    @property
    def unk_token_id(self):
        # Look up the fairseq id directly so the OOV counters stay untouched.
        return self.fairseq_tokens_to_ids.get(self.unk_token)

    def reset_stats(self):
        self._num_tokens_converted = 0
        self._num_tokens_oov = 0

    def get_stats(self):
        # Guard against division by zero before any tokens have been converted.
        if self._num_tokens_converted:
            oov_rate = self._num_tokens_oov / self._num_tokens_converted
        else:
            oov_rate = 0.0
        return {
            "total": self._num_tokens_converted,
            "oov": self._num_tokens_oov,
            "oov_rate": oov_rate,
        }

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        self._num_tokens_converted += 1
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        self._num_tokens_oov += 1
        return self.unk_token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.fairseq_ids_to_tokens.get(index, self.unk_token)

    def _add_from_file(self, f):
        """
        Source: fairseq Dictionary class.
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self._add_from_file(fd)
            except UnicodeError:
                raise Exception(
                    "Incorrect encoding detected in {}, please rebuild the dataset".format(f)
                )
            return

        # Each line is "<symbol> <count>", optionally followed by the
        # "#fairseq:overwrite" flag.
        for line in f.readlines():
            try:
                line, field = line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                spm_id = line
                if spm_id in self.spm_id_to_fairseq_id and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(spm_id)
                    )
                self._add_symbol(spm_id, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                )

    def _add_symbol(self, spm_id, n=1, overwrite=False):
        """
        Source: fairseq Dictionary class.
        Adds a word to the dictionary.
        """
        if spm_id in self.spm_id_to_fairseq_id and not overwrite:
            # Known symbol: just bump its count.
            idx = self.spm_id_to_fairseq_id[spm_id]
            self.count[idx] += n
            return idx
        idx = len(self.symbols)
        self.spm_id_to_fairseq_id[spm_id] = idx
        self.symbols.append(spm_id)
        self.count.append(n)
        return idx

    def _build_fairseq_tokens_to_ids(self):
        fairseq_tokens_to_ids = {}
        for spm_id, fairseq_id in self.spm_id_to_fairseq_id.items():
            # fairseq pads its dictionaries with placeholder entries such as
            # "madeupword0000"; they have no SentencePiece piece, so skip them.
            if isinstance(spm_id, str) and "madeup" in spm_id:
                print("[PASS] spm_id: {} | fairseq_id: {}".format(spm_id, fairseq_id))
                continue
            token = self.sp_model.IdToPiece(int(spm_id))
            fairseq_tokens_to_ids[str(token)] = fairseq_id
        return fairseq_tokens_to_ids
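

# A minimal usage sketch; the paths below are placeholders and assume a local
# copy of the SentencePiece model and the matching fairseq dict.txt.
if __name__ == "__main__":
    tokenizer = FairSeqRobertaSentencePieceTokenizer(
        spm_model="fairseq-roberta-all-model/spm.model",  # assumed local path
        custom_vocab_file="fairseq-roberta-all-model/dict.txt",  # assumed local path
    )
    encoding = tokenizer("Hello world")
    print(encoding["input_ids"])
    print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))
    # Share of tokens that fell back to the unk id since the last reset_stats().
    print(tokenizer.get_stats())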
|
|