# sentencepiece_tokenizer.py
# Custom SentencePiece tokenizer (from the "Sentencepiece_tokenize" upload).
from transformers import PreTrainedTokenizer
import sentencepiece as spm
import os
from logging import getLogger
from typing import List
logger = getLogger()
class SPTokenizer(PreTrainedTokenizer):
    """Tokenize and encode/decode text using a SentencePiece model.

    Thin wrapper exposing a ``sentencepiece.SentencePieceProcessor`` through
    the ``transformers.PreTrainedTokenizer`` interface so it can be used with
    the usual ``save_pretrained`` / ``from_pretrained`` flow.
    """

    def __init__(self, model_path: str, vocab_file: str, **kwargs):
        """Load a trained SentencePiece model.

        Args:
            model_path: Path to the ``.model`` file produced by SentencePiece.
            vocab_file: Path to the matching ``.vocab`` file (tab-separated
                ``piece<TAB>score`` lines).
            **kwargs: Extra arguments forwarded to ``PreTrainedTokenizer``.

        Raises:
            FileNotFoundError: If *model_path* does not exist.
        """
        # Fail early with a real exception: `assert` is stripped under `-O`.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"SentencePiece model not found: {model_path}")
        # Make sure the base class also sees the vocab-file path.
        kwargs.setdefault("vocab_file", vocab_file)
        super().__init__(**kwargs)
        # Store initialization arguments so the tokenizer can be re-created
        # (get_vocab() below reads the vocab_file path from here).
        self.init_kwargs = {"model_path": model_path, "vocab_file": vocab_file, **kwargs}
        # Reload the SentencePiece processor from disk.
        self.sp_model = spm.SentencePieceProcessor(model_file=model_path)
        logger.info("Reloaded SentencePiece model from %s", model_path)
        # Cache vocabulary size and the special-token IDs.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            "#words: %d - BOS ID: %d - EOS ID: %d",
            self.n_words, self.bos_id, self.eos_id,
        )
        # Sanity check: vocab size and piece count must agree.
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode *s* into token IDs, optionally adding BOS/EOS markers.

        Defaults were added for convenience; positional callers passing
        ``(s, bos, eos)`` behave exactly as before.
        """
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Decode a list of token IDs back into text."""
        return self.sp_model.decode(t)

    @property
    def vocab_size(self) -> int:
        """Size of the base vocabulary (without the added tokens)."""
        return self.n_words

    def save_pretrained(self, save_directory: str, max_shard_size=None, safe_serialization=None):
        """Save the tokenizer via the base class, plus a plain-text vocab file.

        Bug fix: the received ``max_shard_size`` / ``safe_serialization``
        arguments are now forwarded instead of being hard-coded to ``None``.
        """
        super().save_pretrained(
            save_directory,
            max_shard_size=max_shard_size,
            safe_serialization=safe_serialization,
        )
        # Also write the vocabulary next to the tokenizer config.
        self.save_vocabulary(save_directory, filename_prefix="vocab")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Build a tokenizer from a local directory.

        NOTE(review): the model/vocab filenames are hard-coded to the
        ``spmodel_wikiqa`` artifacts this repository ships with — confirm
        before reusing this class with other checkpoints.
        """
        model_file = os.path.join(pretrained_model_name_or_path, "spmodel_wikiqa.model")
        vocab_file = os.path.join(pretrained_model_name_or_path, "spmodel_wikiqa.vocab")
        return cls(model_path=model_file, vocab_file=vocab_file, **kwargs)

    def get_vocab(self):
        """Return the ``piece -> id`` mapping read from the vocab file.

        Bug fix: SentencePiece ``.vocab`` files are tab-separated
        ``piece<TAB>score`` lines; only the piece is used as the key now
        (previously the score was kept inside the key). Plain one-word-per-line
        files (as written by :meth:`save_vocabulary`) still parse identically.
        """
        with open(self.init_kwargs["vocab_file"], "r", encoding="utf-8") as f:
            return {
                line.rstrip("\r\n").split("\t")[0]: i
                for i, line in enumerate(f)
            }

    def save_vocabulary(self, save_directory, filename_prefix="vocab"):
        """Write one vocabulary piece per line, ordered by token id.

        Args:
            save_directory: Target directory for the vocabulary file.
            filename_prefix: Stem of the output file; defaults to ``"vocab"``
                (also used when the base class passes ``None``).

        Returns:
            A one-tuple with the written file path (transformers convention).
        """
        prefix = filename_prefix or "vocab"
        vocab_file = os.path.join(save_directory, f"{prefix}.txt")
        with open(vocab_file, "w", encoding="utf-8") as f:
            # Sort by token id so line number == id in the written file.
            for word, _ in sorted(self.get_vocab().items(), key=lambda item: item[1]):
                f.write(f"{word}\n")
        return (vocab_file,)