File size: 2,661 Bytes
f1660c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import os
from typing import Dict, List, Optional, Tuple, Union

from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.utils.hub import cached_file
class SentencePieceJA(PreTrainedTokenizer):
    """Slow-tokenizer wrapper around a Hugging Face ``tokenizers.Tokenizer``.

    The backing tokenizer is loaded from a ``tokenizer.json`` file. Tokenization,
    token/id conversion and vocabulary queries are delegated to it, while
    :class:`PreTrainedTokenizer` supplies the rest of the ``transformers`` API.
    """

    def __init__(self,
                 model_path="./tokenizer.json",
                 pad="<PAD>",
                 bos="<BOS>",
                 eos="<EOS>",
                 unk="<UNK>",
                 mask="<MASK>",
                 **kwargs):
        """Load the backing tokenizer and register the special tokens.

        Args:
            model_path: Path to a ``tokenizer.json`` file (relative paths are
                resolved against the current working directory). If loading it
                fails, ``tokenizer.json`` from the ``if001/sentencepiece_ja`` Hub
                repository is fetched with ``cached_file`` and loaded instead.
            pad, bos, eos, unk, mask: Strings registered as the pad / bos / eos /
                unk / mask special tokens.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.

        Raises:
            Whatever ``cached_file`` / ``Tokenizer.from_file`` raise if the
            fallback load also fails (the first error is chained as context).
        """
        from tokenizers import Tokenizer
        from transformers.utils import logging as hf_logging

        try:
            self._tokenizer = Tokenizer.from_file(model_path)
        except Exception as e:
            # Tokenizer.from_file surfaces load errors as a plain Exception, so a
            # narrower clause would not catch a missing or unreadable file.
            hf_logging.get_logger(__name__).warning(
                "Failed to load tokenizer from %s (%s); loading from cache...",
                model_path, e,
            )
            model_path = cached_file('if001/sentencepiece_ja', 'tokenizer.json')
            self._tokenizer = Tokenizer.from_file(model_path)
        # NOTE(review): self._tokenizer is assigned before super().__init__(),
        # presumably because the base constructor may call the overridden
        # methods below that read it. Keep this order.
        super().__init__(**kwargs)
        self.add_special_tokens({
            'pad_token': pad,
            'bos_token': bos,
            'eos_token': eos,
            'unk_token': unk,
            'mask_token': mask
        })
        # NOTE(review): the two entries are identical as written, so the second is
        # redundant. If a different whitespace character (e.g. U+3000) was
        # intended, confirm and fix.
        self._tokenizer.add_tokens([" ", " "])

    def get_vocab(self) -> Dict[str, int]:
        """Return the backing tokenizer's token -> id mapping, including added tokens."""
        return self._tokenizer.get_vocab(with_added_tokens=True)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the backing tokenizer's ``get_vocab_size()``."""
        return self._tokenizer.get_vocab_size()

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into token strings using the backing tokenizer."""
        return self._tokenizer.encode(text).tokens

    def _convert_token_to_id(self, token):
        """Map a token string to its id.

        A string that is a verbatim vocabulary entry is looked up directly.
        Otherwise it is encoded and the id of its first piece is returned. If
        encoding produces no pieces, the unk token id is returned.
        """
        token_id = self._tokenizer.token_to_id(token)
        if token_id is not None:
            return token_id
        # Not a vocabulary entry: encode once and reuse the result.
        ids = self._tokenizer.encode(token).ids
        return ids[0] if ids else self.unk_token_id

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to a string by decoding that single id."""
        # NOTE(review): tokenizers' Tokenizer.decode skips special tokens by
        # default, so a special-token id may come back as "" here. Confirm
        # this is intended.
        return self._tokenizer.decode([index])

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate tokens with no separator (intended for Japanese text)."""
        return "".join(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary to a text file, one token per line in ascending id order.

        Args:
            save_directory: If an existing directory, ``[prefix-]vocab.txt`` is
                written inside it. Otherwise it is treated as the output file path,
                and the prefix is applied to the file name.
            filename_prefix: Optional prefix joined to the file name with ``"-"``.

        Returns:
            A one-element tuple containing the path that was written.
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, prefix + 'vocab.txt')
        else:
            # Prefix the file name only; prepending to the whole path would turn
            # "dir/vocab.txt" into "prefix-dir/vocab.txt".
            head, tail = os.path.split(save_directory)
            vocab_file = os.path.join(head, prefix + tail)
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, _ in sorted(self.get_vocab().items(), key=lambda kv: kv[1]):
                writer.write(token + "\n")
        return (vocab_file,)