|
from transformers import PreTrainedTokenizer |
|
import sentencepiece as spm |
|
import os |
|
from logging import getLogger |
|
from typing import List |
|
|
|
logger = getLogger() |
|
|
|
class SPTokenizer(PreTrainedTokenizer):
    """Tokenizing and encoding/decoding text using SentencePiece.

    Wraps a trained ``sentencepiece`` model behind the Hugging Face
    ``PreTrainedTokenizer`` interface.
    """

    def __init__(self, model_path: str, vocab_file: str, **kwargs):
        """Load the SentencePiece model at ``model_path``.

        Args:
            model_path: Path to the trained SentencePiece ``.model`` file.
            vocab_file: Path to the matching vocabulary file (one entry per
                line; line number is the token id — see ``get_vocab``).
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.

        Raises:
            FileNotFoundError: If ``model_path`` does not point to a file.
        """
        # Make vocab_file visible to the base class without clobbering an
        # explicitly supplied value.
        kwargs.setdefault('vocab_file', vocab_file)

        super().__init__(**kwargs)

        # Remember the constructor arguments so the tokenizer can be
        # re-created / serialized later.
        self.init_kwargs = {"model_path": model_path, "vocab_file": vocab_file, **kwargs}

        # A bare assert would be silently stripped under ``python -O``;
        # raise a proper error for a missing model file instead.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"SentencePiece model file not found: {model_path}")
        self.sp_model = spm.SentencePieceProcessor(model_file=model_path)
        logger.info("Reloaded SentencePiece model from %s", model_path)

        # Cache the vocabulary size and special-token ids reported by the
        # SentencePiece model (ids are -1 when the model defines none).
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            "#words: %s - BOS ID: %s - EOS ID: %s",
            self.n_words, self.bos_id, self.eos_id,
        )
        # Internal sanity check: vocab_size and piece count must agree.
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """Encode ``s`` into token ids.

        Args:
            s: Text to encode.
            bos: If True, prepend the BOS id.
            eos: If True, append the EOS id.

        Returns:
            The list of token ids, with optional BOS/EOS wrapping.
        """
        t: List[int] = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Decode a list of token ids back into text."""
        return self.sp_model.decode(t)

    @property
    def vocab_size(self):
        """Size of the base vocabulary (without the added tokens)."""
        return self.n_words

    def save_pretrained(self, save_directory: str, max_shard_size=None, safe_serialization=None):
        """Save the tokenizer configuration and vocabulary to a directory.

        Bug fix: forward ``max_shard_size`` and ``safe_serialization`` to the
        base class instead of hard-coding ``None`` (the original silently
        discarded caller-supplied values).
        """
        super().save_pretrained(
            save_directory,
            max_shard_size=max_shard_size,
            safe_serialization=safe_serialization,
        )
        # Also write the plain-text vocabulary next to the config files.
        self.save_vocabulary(save_directory, filename_prefix="vocab")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Build a tokenizer from a directory of SentencePiece files.

        Expects ``spmodel_wikiqa.model`` and ``spmodel_wikiqa.vocab`` inside
        ``pretrained_model_name_or_path``.
        """
        model_file = os.path.join(pretrained_model_name_or_path, "spmodel_wikiqa.model")
        vocab_file = os.path.join(pretrained_model_name_or_path, "spmodel_wikiqa.vocab")
        return cls(model_path=model_file, vocab_file=vocab_file, **kwargs)

    def get_vocab(self):
        """Return a token -> id mapping read from the vocabulary file.

        The id of each token is its (0-based) line number in the file.

        NOTE(review): SentencePiece ``.vocab`` files are typically
        tab-separated ``piece<TAB>score`` lines; ``strip()`` keeps the score
        attached to the key — confirm the expected file format.
        """
        with open(self.init_kwargs["vocab_file"], "r", encoding="utf-8") as f:
            return {line.strip(): idx for idx, line in enumerate(f)}

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write the vocabulary to ``<save_directory>/<prefix>.txt``.

        One token per line, ordered by token id so the line number equals the
        id. ``filename_prefix`` defaults to ``"vocab"`` when not supplied —
        the base class may call this with ``filename_prefix=None``, which
        previously produced a file literally named ``None.txt``.

        Returns:
            A one-element tuple with the written file path (HF convention).
        """
        prefix = filename_prefix if filename_prefix is not None else "vocab"
        vocab_file = os.path.join(save_directory, f"{prefix}.txt")
        # Sort by id so the file order matches token ids.
        words = sorted(self.get_vocab().items(), key=lambda item: item[1])
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.writelines(f"{word}\n" for word, _ in words)
        return (vocab_file,)
|
|