File size: 2,661 Bytes
f1660c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
from typing import Dict, List, Optional, Tuple, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, AutoTokenizer
from transformers.utils.hub import cached_file
class SentencePieceJA(PreTrainedTokenizer):
    """Slow-tokenizer wrapper around a ``tokenizers.Tokenizer`` JSON model.

    The underlying tokenizer is loaded from ``model_path``; if that fails,
    ``tokenizer.json`` is fetched from the ``if001/sentencepiece_ja`` hub
    repository instead. All tokenization and vocabulary queries are
    delegated to that underlying tokenizer.
    """

    def __init__(self,
                 model_path: str = "./tokenizer.json",
                 pad: str = "<PAD>",
                 bos: str = "<BOS>",
                 eos: str = "<EOS>",
                 unk: str = "<UNK>",
                 mask: str = "<MASK>",
                 **kwargs):
        """Load the tokenizer model and register the special tokens.

        Args:
            model_path: Path to a ``tokenizers`` JSON file.
            pad: Padding token string.
            bos: Beginning-of-sequence token string.
            eos: End-of-sequence token string.
            unk: Unknown token string.
            mask: Mask token string.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        from tokenizers import Tokenizer

        # The backend must exist before super().__init__() runs, because the
        # base class may query the vocabulary during its own initialisation.
        try:
            self._tokenizer = Tokenizer.from_file(model_path)
        except Exception as e:
            # Deliberate best-effort fallback: any failure to read the local
            # file falls through to the hub-cached copy of the model.
            print('exception: ', e)
            print('load from cache...')
            model_path = cached_file('if001/sentencepiece_ja', 'tokenizer.json')
            self._tokenizer = Tokenizer.from_file(model_path)
        super().__init__(**kwargs)
        self.add_special_tokens({
            'pad_token': pad,
            'bos_token': bos,
            'eos_token': eos,
            'unk_token': unk,
            'mask_token': mask
        })
        # NOTE(review): both entries look identical here; one may have been
        # meant to be a different whitespace character -- confirm upstream.
        self._tokenizer.add_tokens([" ", " "])

    def get_vocab(self) -> Dict[str, int]:
        """Return the underlying tokenizer's token-to-id mapping."""
        return self._tokenizer.get_vocab()

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the underlying tokenizer."""
        return self._tokenizer.get_vocab_size()

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into token strings using the underlying tokenizer."""
        return self._tokenizer.encode(text).tokens

    def _convert_token_to_id(self, token):
        """Return the first id produced by encoding ``token``.

        Falls back to ``unk_token_id`` when encoding yields no ids.
        """
        # Encode once and reuse the result instead of encoding a second time.
        ids = self._tokenizer.encode(token).ids
        if not ids:
            return self.unk_token_id
        return ids[0]

    def _convert_id_to_token(self, index: int) -> str:
        """Return the string obtained by decoding the single id ``index``."""
        return self._tokenizer.decode([index])

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate ``tokens`` with no separator."""
        # For Japanese text: pieces are joined directly, without spaces.
        return "".join(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary to a text file, one token per line in id order.

        Args:
            save_directory: Existing directory to write ``vocab.txt`` into;
                if it is not a directory, the value itself is used as the
                output file name.
            filename_prefix: Optional prefix, joined to the file name with
                ``"-"``.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, prefix + 'vocab.txt')
        else:
            vocab_file = prefix + save_directory

        vocab = self.get_vocab()
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Sorting by id keeps line order equal to id order.
            for token, _ in sorted(vocab.items(), key=lambda kv: kv[1]):
                writer.write(token + "\n")
        return (vocab_file,)