import os
from logging import getLogger
from typing import List

import sentencepiece as spm
from transformers import PreTrainedTokenizer

logger = getLogger()


class SPTokenizer(PreTrainedTokenizer):
    """Tokenizing and encoding/decoding text using SentencePiece."""

    def __init__(self, model_path: str, vocab_file: str, **kwargs):
        # Add 'vocab_file' to kwargs if it's not present.
        kwargs.setdefault("vocab_file", vocab_file)

        # Load the SentencePiece model *before* calling super().__init__():
        # recent versions of transformers invoke get_vocab() during base-class
        # initialization, which requires self.sp_model to already exist.
        assert os.path.isfile(model_path), model_path
        self.sp_model = spm.SentencePieceProcessor(model_file=model_path)
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # BOS / EOS / PAD token IDs.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        super().__init__(**kwargs)

        # Store initialization arguments so the tokenizer can be re-created.
        self.init_kwargs = {"model_path": model_path, "vocab_file": vocab_file, **kwargs}

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)

    @property
    def vocab_size(self):
        """Size of the base vocabulary (without the added tokens)."""
        return self.n_words

    def save_pretrained(self, save_directory: str, max_shard_size=None, safe_serialization=None):
        # max_shard_size / safe_serialization are accepted so callers that pass
        # them don't break, but they are not forwarded: the base tokenizer's
        # save_pretrained() does not use them.
        super().save_pretrained(save_directory)
        # Also save the vocabulary to a plain-text file.
        self.save_vocabulary(save_directory, filename_prefix="vocab")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        # Load the SentencePiece model and vocabulary files from the directory.
        model_file = os.path.join(pretrained_model_name_or_path, "spmodel_wikiqa.model")
        vocab_file = os.path.join(pretrained_model_name_or_path, "spmodel_wikiqa.vocab")
        return cls(model_path=model_file, vocab_file=vocab_file, **kwargs)

    def get_vocab(self):
        # SentencePiece .vocab files are tab-separated "piece<TAB>score" lines,
        # so keep only the piece (the first field) when building the mapping.
        with open(self.init_kwargs["vocab_file"], "r", encoding="utf-8") as f:
            vocab = {line.split("\t")[0]: i for i, line in enumerate(f)}
        return vocab

    def save_vocabulary(self, save_directory, filename_prefix=None):
        vocab_file = os.path.join(save_directory, f"{filename_prefix or 'vocab'}.txt")
        with open(vocab_file, "w", encoding="utf-8") as f:
            # Write one piece per line, ordered by token ID.
            for word, _ in sorted(self.get_vocab().items(), key=lambda x: x[1]):
                f.write(f"{word}\n")
        return (vocab_file,)
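

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the tokenizer itself). It assumes a
# SentencePiece model has already been trained, for example with:
#
#   spm.SentencePieceTrainer.train(
#       input="wikiqa.txt", model_prefix="spmodel_wikiqa", vocab_size=8000
#   )
#
# which produces spmodel_wikiqa.model and spmodel_wikiqa.vocab, the filenames
# that from_pretrained() above expects. The corpus path and vocab size here
# are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tokenizer = SPTokenizer(
        model_path="spmodel_wikiqa.model",
        vocab_file="spmodel_wikiqa.vocab",
    )
    ids = tokenizer.encode("Hello world", bos=True, eos=True)
    print(ids)                    # e.g. [1, 35, 210, 2]; IDs depend on the model
    print(tokenizer.decode(ids))  # control tokens (BOS/EOS) are dropped on decode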