|
import os |
|
import torch |
|
|
|
class SPieceTokenizer:
    """Thin wrapper around a ``sentencepiece.SentencePieceProcessor``.

    Exposes a small dict-returning tokenizer surface (``from_pretrained``,
    ``get_vocab``, and calling the instance to get ``{"input_ids": ...}``).
    It can also serialize its model proto into a uint8 tensor, which the
    constructor accepts back.
    """

    # Class-level defaults forwarded to SentencePieceProcessor. Subclasses may
    # override them; __init__ may also override them per instance.
    add_bos = False
    add_eos = True

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        """Construct a tokenizer from *path* (see ``__init__``).

        This is a classmethod rather than a staticmethod that hard-codes the
        base class. As a result, ``Subclass.from_pretrained`` returns a
        ``Subclass`` instance and honours the subclass's class-level
        ``add_bos``/``add_eos`` defaults.
        """
        return cls(path, **kwargs)

    def __init__(self, tokenizer_path, add_bos=None, add_eos=None):
        """Load the SentencePiece model.

        Args:
            tokenizer_path: one of
                - a tensor holding the raw model proto bytes (for example the
                  output of ``serialize_model``),
                - ``bytes`` containing the serialized model proto,
                - anything else, passed through as ``model_file`` (a path).
            add_bos: if not None, overrides the class-level ``add_bos``.
            add_eos: if not None, overrides the class-level ``add_eos``.
        """
        # Imported here rather than at module level, so sentencepiece is only
        # required when a tokenizer is actually constructed.
        import sentencepiece

        if add_bos is not None:
            self.add_bos = add_bos
        if add_eos is not None:
            self.add_eos = add_eos

        if torch.is_tensor(tokenizer_path):
            # .numpy() only works on CPU tensors; .cpu() is a no-op for those.
            tokenizer_path = tokenizer_path.cpu().numpy().tobytes()

        options = {"add_bos": self.add_bos, "add_eos": self.add_eos}
        if isinstance(tokenizer_path, bytes):
            self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, **options)
        else:
            self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, **options)

    def get_vocab(self):
        """Return a ``{piece: id}`` dict covering every id in the model."""
        tok = self.tokenizer
        return {tok.id_to_piece(i): i for i in range(tok.get_piece_size())}

    def __call__(self, string):
        """Encode *string* and return ``{"input_ids": ids}``.

        BOS/EOS insertion follows the options the underlying processor was
        constructed with.
        """
        return {"input_ids": self.tokenizer.encode(string)}

    def serialize_model(self):
        """Return the serialized model proto as a 1-D ``torch.uint8`` tensor.

        The result can be passed back to ``SPieceTokenizer(...)``.
        """
        proto = self.tokenizer.serialized_model_proto()
        if not proto:
            # torch.frombuffer rejects zero-length buffers.
            return torch.empty(0, dtype=torch.uint8)
        # frombuffer over a private, writable copy avoids building a Python
        # list with one int per byte. The tensor keeps the bytearray alive.
        return torch.frombuffer(bytearray(proto), dtype=torch.uint8)
|
|