cointegrated committed on
Commit b9a4929
1 Parent(s): 5b483a1

Create char_tokenizer.py

Files changed (1)
  1. char_tokenizer.py +162 -0
char_tokenizer.py ADDED
@@ -0,0 +1,162 @@
"""
Mostly copy-pasted from
https://huggingface.co/IlyaGusev/ru-word-stress-transformer/blob/main/char_tokenizer.py
(Apache 2.0 license).
"""

import os
from typing import Optional, Tuple, List
from collections import OrderedDict

from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, AutoTokenizer


def load_vocab(vocab_file):
    """Read a vocab file (one token per line) into an ordered token -> id map."""
    vocab = OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


class CharTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer: every character of the input is one token."""

    vocab_files_names = {"vocab_file": "vocab.txt"}

    def __init__(
        self,
        vocab_file=None,
        pad_token="[pad]",
        unk_token="[unk]",
        bos_token="[bos]",
        eos_token="[eos]",
        cls_token="[cls]",
        sep_token="[sep]",
        mask_token="[mask]",
        space_token="▁",
        do_lower_case=False,
        *args,
        **kwargs
    ):
        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            cls_token=cls_token,
            sep_token=sep_token,  # forward sep_token too, so it is actually registered
            mask_token=mask_token,
            do_lower_case=do_lower_case,
            **kwargs
        )
        self.do_lower_case = do_lower_case
        self.space_token = space_token

        if not vocab_file or not os.path.isfile(vocab_file):
            # No vocab file yet: start empty and build the vocabulary later with `train()`.
            self.vocab = OrderedDict()
            self.ids_to_tokens = OrderedDict()
        else:
            self.vocab = load_vocab(vocab_file)
            self.ids_to_tokens = OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])

    def train(self, file_path):
        """Build the character vocabulary from a text file with one word per line."""
        vocab = set()
        with open(file_path) as r:
            for line in r:
                word = line.strip()
                if self.do_lower_case:
                    word = word.lower()
                vocab |= set(word)
        vocab = sorted(vocab)
        special_tokens = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]
        vocab = special_tokens + vocab

        for i, ch in enumerate(vocab):
            self.vocab[ch] = i
        self.ids_to_tokens = vocab

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return self.vocab

    def _convert_token_to_id(self, token):
        if self.do_lower_case:
            token = token.lower()
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        return self.ids_to_tokens[index]

    def prepare_for_tokenization(
        self, text, is_split_into_words: bool = False, spaces=0, **kwargs
    ):
        # Optionally surround and separate every character with `spaces` copies
        # of the space token before tokenization.
        if spaces:
            pad = self.space_token * spaces
            text = pad + pad.join(text) + pad
        return (text, kwargs)

    def _tokenize(self, text, spaces=0):
        if self.do_lower_case:
            text = text.lower()
        return list(text)

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Single-sequence model: wrap the ids in [bos] ... [eos].
        bos = [self.bos_token_id]
        eos = [self.eos_token_id]
        return bos + token_ids_0 + eos

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        return (len(token_ids_0) + 2) * [0]

    def save_vocabulary(
        self,
        save_directory: str,
        filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        assert os.path.isdir(save_directory)
        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") +
            self.vocab_files_names["vocab_file"]
        )
        index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            # Write tokens one per line, in id order.
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                assert index == token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)

    def clean_up_tokenization(self, text, space='▁'):
        # Drop space tokens and collapse runs of repeated characters.
        res = []
        prev = space
        for c in text:
            if c != prev and c != space:
                res.append(c)
            prev = c
        return ''.join(res)


# Make the class discoverable through the AutoTokenizer machinery.
AutoTokenizer.register("char_tokenizer", CharTokenizer)
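
For orientation, a minimal usage sketch of `CharTokenizer`: build a character vocabulary from a word list, save it, and reload the tokenizer from the saved files. The `words.txt` file and the `char_tokenizer_demo` directory are placeholder names, and the flow assumes a transformers version contemporary with this commit (later releases changed how PreTrainedTokenizer is initialized).

    # Usage sketch with placeholder paths.
    from char_tokenizer import CharTokenizer

    tokenizer = CharTokenizer()
    tokenizer.train("words.txt")                       # one word per line
    tokenizer.save_pretrained("char_tokenizer_demo")   # writes vocab.txt + tokenizer configs

    tokenizer = CharTokenizer.from_pretrained("char_tokenizer_demo")
    encoded = tokenizer("привет")
    print(encoded["input_ids"])                        # [bos] id, one id per character, [eos] id
    print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))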