datnth1709 committed
Commit a079f98 · 1 Parent(s): 3792c6a

Upload envibert_tokenizer.py

Files changed (1)
  1. envibert_tokenizer.py +321 -0
envibert_tokenizer.py ADDED
@@ -0,0 +1,321 @@
+ # !pip install sentencepiece==0.1.96 transformers==4.10.0
+ import sentencepiece as spm
+ import os
+ from transformers import PreTrainedTokenizer
+ from collections import Counter
+ from typing import List, Optional, Tuple
+
+
+ class RobertaTokenizer(PreTrainedTokenizer):
+     def __init__(
+         self,
+         pretrained_file,
+         bos_token="<s>",
+         eos_token="</s>",
+         sep_token="</s>",
+         cls_token="<s>",
+         unk_token="<unk>",
+         pad_token="<pad>",
+         mask_token="<mask>",
+         **kwargs
+     ):
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             unk_token=unk_token,
+             sep_token=sep_token,
+             cls_token=cls_token,
+             pad_token=pad_token,
+             mask_token=mask_token,
+             **kwargs,
+         )
+
+         # Load the BPE model and the fairseq vocab file
+         sentencepiece_model = os.path.join(pretrained_file, 'sentencepiece.bpe.model')
+         vocab_file = os.path.join(pretrained_file, 'dict.txt')
+         self.sp_model = spm.SentencePieceProcessor()
+         # sp_model is used only to split text into pieces; token ids always come
+         # from the fairseq dictionary below, never from sp_model itself.
+         self.sp_model.Load(sentencepiece_model)
+
+         self.bpe_dict = Dictionary.load(vocab_file)
+
+         # Mimic fairseq token-to-id alignment for the first 4 tokens
+         self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+         # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+         self.fairseq_offset = 0
+
+         self.fairseq_tokens_to_ids["<mask>"] = len(self.bpe_dict) + self.fairseq_offset
+         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
+
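+     # Sketch of the resulting id layout (assuming dict.txt follows the usual
+     # fairseq convention of listing only non-special symbols): ids 0-3 are
+     # <s>, <pad>, </s>, <unk>; ordinary BPE pieces start at id 4; "<mask>" is
+     # appended at id len(self.bpe_dict), which is why vocab_size below is
+     # len(self.bpe_dict) + self.fairseq_offset + 1.
+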
+     def _tokenize(self, text):
+         return self.sp_model.EncodeAsPieces(text)
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         # TODO: copy sentencepiece.bpe.model and dict.txt into save_directory
+         return "", ""
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) into an id using the vocab."""
+         if token in self.fairseq_tokens_to_ids:
+             return self.fairseq_tokens_to_ids[token]
+         spm_id = self.bpe_dict.index(token)
+         return spm_id
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) into a token (str) using the vocab."""
+         if index in self.fairseq_ids_to_tokens:
+             return self.fairseq_ids_to_tokens[index]
+         return self.bpe_dict[index]
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Build model inputs from a sequence or a pair of sequences for sequence classification tasks by
+         concatenating and adding special tokens. An XLM-RoBERTa-style sequence has the following format:
+
+         - single sequence: ``<s> X </s>``
+         - pair of sequences: ``<s> A </s></s> B </s>``
+
+         Args:
+             token_ids_0 (:obj:`List[int]`): The first tokenized sequence.
+             token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence.
+
+         Returns:
+             :obj:`List[int]`: The model input with special tokens.
+         """
+         if token_ids_1 is None:
+             return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+         cls = [self.cls_token_id]
+         sep = [self.sep_token_id]
+         return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa
+         does not make use of token type ids, therefore a list of zeros is returned.
+
+         Args:
+             token_ids_0 (:obj:`List[int]`):
+                 List of IDs.
+             token_ids_1 (:obj:`List[int]`, `optional`):
+                 Optional second list of IDs for sequence pairs.
+
+         Returns:
+             :obj:`List[int]`: List of zeros.
+         """
+
+         sep = [self.sep_token_id]
+         cls = [self.cls_token_id]
+
+         if token_ids_1 is None:
+             return len(cls + token_ids_0 + sep) * [0]
+         return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+     @property
+     def vocab_size(self):
+         return len(self.bpe_dict) + self.fairseq_offset + 1  # Add the <mask> token
+
+     def get_vocab(self):
+         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+
+ class Dictionary(object):
+     """A mapping from symbols to consecutive integers"""
+
+     def __init__(
+         self,
+         pad='<pad>',
+         eos='</s>',
+         unk='<unk>',
+         bos='<s>',
+         extra_special_symbols=None,
+     ):
+         self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
+         self.symbols = []
+         self.count = []
+         self.indices = {}
+         self.bos_index = self.add_symbol(bos)
+         self.pad_index = self.add_symbol(pad)
+         self.eos_index = self.add_symbol(eos)
+         self.unk_index = self.add_symbol(unk)
+         if extra_special_symbols:
+             for s in extra_special_symbols:
+                 self.add_symbol(s)
+         self.nspecial = len(self.symbols)
+
+     def __eq__(self, other):
+         return self.indices == other.indices
+
+     def __getitem__(self, idx):
+         if idx < len(self.symbols):
+             return self.symbols[idx]
+         return self.unk_word
+
+     def __len__(self):
+         """Returns the number of symbols in the dictionary"""
+         return len(self.symbols)
+
+     def __contains__(self, sym):
+         return sym in self.indices
+
+     def index(self, sym):
+         """Returns the index of the specified symbol"""
+         assert isinstance(sym, str)
+         if sym in self.indices:
+             return self.indices[sym]
+         return self.unk_index
+
+     def unk_string(self, escape=False):
+         """Return unknown string, optionally escaped as: <<unk>>"""
+         if escape:
+             return '<{}>'.format(self.unk_word)
+         else:
+             return self.unk_word
+
+     def add_symbol(self, word, n=1):
+         """Adds a word to the dictionary"""
+         if word in self.indices:
+             idx = self.indices[word]
+             self.count[idx] = self.count[idx] + n
+             return idx
+         else:
+             idx = len(self.symbols)
+             self.indices[word] = idx
+             self.symbols.append(word)
+             self.count.append(n)
+             return idx
+
+     def update(self, new_dict):
+         """Updates counts from new dictionary."""
+         for word in new_dict.symbols:
+             idx2 = new_dict.indices[word]
+             if word in self.indices:
+                 idx = self.indices[word]
+                 self.count[idx] = self.count[idx] + new_dict.count[idx2]
+             else:
+                 idx = len(self.symbols)
+                 self.indices[word] = idx
+                 self.symbols.append(word)
+                 self.count.append(new_dict.count[idx2])
+
+     def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
+         """Sort symbols by frequency in descending order, ignoring special ones.
+
+         Args:
+             - threshold defines the minimum word count
+             - nwords defines the total number of words in the final dictionary,
+               including special symbols
+             - padding_factor can be used to pad the dictionary size to be a
+               multiple of 8, which is important on some hardware (e.g., Nvidia
+               Tensor Cores).
+         """
+         if nwords <= 0:
+             nwords = len(self)
+
+         new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
+         new_symbols = self.symbols[:self.nspecial]
+         new_count = self.count[:self.nspecial]
+
+         c = Counter(dict(sorted(zip(self.symbols[self.nspecial:], self.count[self.nspecial:]))))
+         for symbol, count in c.most_common(nwords - self.nspecial):
+             if count >= threshold:
+                 new_indices[symbol] = len(new_symbols)
+                 new_symbols.append(symbol)
+                 new_count.append(count)
+             else:
+                 break
+
+         threshold_nwords = len(new_symbols)
+         if padding_factor > 1:
+             i = 0
+             while threshold_nwords % padding_factor != 0:
+                 symbol = 'madeupword{:04d}'.format(i)
+                 new_indices[symbol] = len(new_symbols)
+                 new_symbols.append(symbol)
+                 new_count.append(0)
+                 i += 1
+                 threshold_nwords += 1
+
+         assert len(new_symbols) % padding_factor == 0
+         assert len(new_symbols) == len(new_indices)
+
+         self.count = list(new_count)
+         self.symbols = list(new_symbols)
+         self.indices = new_indices
+
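+     # For example (illustrative numbers only): with nspecial = 4 and 9,999
+     # retained symbols, the dictionary holds 10,003 entries; 10,003 % 8 == 3,
+     # so finalize() appends 'madeupword0000' .. 'madeupword0004' to round the
+     # size up to 10,008, a multiple of padding_factor = 8.
+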
+     def bos(self):
+         """Helper to get index of beginning-of-sentence symbol"""
+         return self.bos_index
+
+     def pad(self):
+         """Helper to get index of pad symbol"""
+         return self.pad_index
+
+     def eos(self):
+         """Helper to get index of end-of-sentence symbol"""
+         return self.eos_index
+
+     def unk(self):
+         """Helper to get index of unk symbol"""
+         return self.unk_index
+
+     @classmethod
+     def load(cls, f):
+         """Loads the dictionary from a text file with the format:
+
+         ```
+         <symbol0> <count0>
+         <symbol1> <count1>
+         ...
+         ```
+         """
+         d = cls()
+         d.add_from_file(f)
+         return d
+
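+     # For illustration, a fairseq-style dict.txt is one "<piece> <count>" pair
+     # per line, e.g. (hypothetical excerpt, counts are made up):
+     #     , 9945661
+     #     ▁the 8012345
+     #     ▁của 7654321
+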
+     def add_from_file(self, f):
+         """
+         Loads a pre-existing dictionary from a text file and adds its symbols
+         to this instance.
+         """
+         if isinstance(f, str):
+             try:
+                 with open(f, 'r', encoding='utf-8') as fd:
+                     self.add_from_file(fd)
+             except FileNotFoundError as fnfe:
+                 raise fnfe
+             except UnicodeError:
+                 raise Exception("Incorrect encoding detected in {}, please "
+                                 "rebuild the dataset".format(f))
+             return
+
+         lines = f.readlines()
+         indices_start_line = self._load_meta(lines)
+         for line in lines[indices_start_line:]:
+             idx = line.rfind(' ')
+             if idx == -1:
+                 raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
+             word = line[:idx]
+             count = int(line[idx + 1:])
+             self.indices[word] = len(self.symbols)
+             self.symbols.append(word)
+             self.count.append(count)
+
+     def _save(self, f, kv_iterator):
+         if isinstance(f, str):
+             os.makedirs(os.path.dirname(f), exist_ok=True)
+             with open(f, 'w', encoding='utf-8') as fd:
+                 return self.save(fd)
+         for k, v in kv_iterator:
+             print('{} {}'.format(k, v), file=f)
+
+     def _get_meta(self):
+         return [], []
+
+     def _load_meta(self, lines):
+         return 0
+
+     def save(self, f):
+         """Stores dictionary into a text file"""
+         ex_keys, ex_vals = self._get_meta()
+         self._save(f, zip(ex_keys + self.symbols[self.nspecial:], ex_vals + self.count[self.nspecial:]))
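
The uploaded module can be used as a standalone slow tokenizer once the model's `sentencepiece.bpe.model` and `dict.txt` are available locally. A minimal usage sketch, assuming those two files have been downloaded into a directory named `./envibert` (the directory name and the example sentence are only illustrative):

```python
from envibert_tokenizer import RobertaTokenizer

# Assumes ./envibert contains sentencepiece.bpe.model and dict.txt
tokenizer = RobertaTokenizer(pretrained_file='./envibert')

# __call__ tokenizes with sentencepiece, maps pieces to fairseq-aligned ids,
# and wraps the sequence as <s> ... </s> via build_inputs_with_special_tokens
encoding = tokenizer("Xin chào Việt Nam")
print(encoding["input_ids"])
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))
```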