VarunGumma commited on
Commit
53ee433
1 Parent(s): 684893f

Upload tokenization_indictrans.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tokenization_indictrans.py +261 -0
tokenization_indictrans.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from typing import Dict, List, Optional, Union, Tuple
5
+
6
+ from transformers.utils import logging
7
+ from sentencepiece import SentencePieceProcessor
8
+ from transformers.tokenization_utils import PreTrainedTokenizer
9
+
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+ SPIECE_UNDERLINE = "▁"
14
+
15
+ SPECIAL_TAGS = {
16
+ "_bt_",
17
+ "_ft_",
18
+ "asm_Beng",
19
+ "awa_Deva",
20
+ "ben_Beng",
21
+ "bho_Deva",
22
+ "brx_Deva",
23
+ "doi_Deva",
24
+ "eng_Latn",
25
+ "gom_Deva",
26
+ "gon_Deva",
27
+ "guj_Gujr",
28
+ "hin_Deva",
29
+ "hne_Deva",
30
+ "kan_Knda",
31
+ "kas_Arab",
32
+ "kas_Deva",
33
+ "kha_Latn",
34
+ "lus_Latn",
35
+ "mag_Deva",
36
+ "mai_Deva",
37
+ "mal_Mlym",
38
+ "mar_Deva",
39
+ "mni_Beng",
40
+ "mni_Mtei",
41
+ "npi_Deva",
42
+ "ory_Orya",
43
+ "pan_Guru",
44
+ "san_Deva",
45
+ "sat_Olck",
46
+ "snd_Arab",
47
+ "snd_Deva",
48
+ "tam_Taml",
49
+ "tel_Telu",
50
+ "urd_Arab",
51
+ "unr_Deva",
52
+ }
53
+
54
+ VOCAB_FILES_NAMES = {
55
+ "src_vocab_fp": "dict.SRC.json",
56
+ "tgt_vocab_fp": "dict.TGT.json",
57
+ "src_spm_fp": "model.SRC",
58
+ "tgt_spm_fp": "model.TGT",
59
+ }
60
+
61
+
62
+ class IndicTransTokenizer(PreTrainedTokenizer):
63
+ _added_tokens_encoder = {}
64
+ _added_tokens_decoder = {}
65
+
66
+ vocab_files_names = VOCAB_FILES_NAMES
67
+ model_input_names = ["input_ids", "attention_mask"]
68
+
69
+ def __init__(
70
+ self,
71
+ src_vocab_fp=None,
72
+ tgt_vocab_fp=None,
73
+ src_spm_fp=None,
74
+ tgt_spm_fp=None,
75
+ unk_token="<unk>",
76
+ bos_token="<s>",
77
+ eos_token="</s>",
78
+ pad_token="<pad>",
79
+ do_lower_case=False,
80
+ **kwargs,
81
+ ):
82
+
83
+ self.src = True
84
+
85
+ self.src_vocab_fp = src_vocab_fp
86
+ self.tgt_vocab_fp = tgt_vocab_fp
87
+ self.src_spm_fp = src_spm_fp
88
+ self.tgt_spm_fp = tgt_spm_fp
89
+
90
+ self.unk_token = unk_token
91
+ self.pad_token = pad_token
92
+ self.eos_token = eos_token
93
+ self.bos_token = bos_token
94
+
95
+ self.encoder = self._load_json(self.src_vocab_fp)
96
+ if self.unk_token not in self.encoder:
97
+ raise KeyError("<unk> token must be in vocab")
98
+ assert self.pad_token in self.encoder
99
+ self.encoder_rev = {v: k for k, v in self.encoder.items()}
100
+
101
+ self.decoder = self._load_json(self.tgt_vocab_fp)
102
+ if self.unk_token not in self.encoder:
103
+ raise KeyError("<unk> token must be in vocab")
104
+ assert self.pad_token in self.encoder
105
+ self.decoder_rev = {v: k for k, v in self.decoder.items()}
106
+
107
+ # load SentencePiece model for pre-processing
108
+ self.src_spm = self._load_spm(self.src_spm_fp)
109
+ self.tgt_spm = self._load_spm(self.tgt_spm_fp)
110
+
111
+ self.current_spm = self.src_spm
112
+ self.current_encoder = self.encoder
113
+ self.current_encoder_rev = self.encoder_rev
114
+
115
+ self.unk_token_id = self.encoder[self.unk_token]
116
+ self.pad_token_id = self.encoder[self.pad_token]
117
+ self.eos_token_id = self.encoder[self.eos_token]
118
+ self.bos_token_id = self.encoder[self.bos_token]
119
+
120
+ super().__init__(
121
+ src_vocab_file=self.src_vocab_fp,
122
+ tgt_vocab_file=self.src_vocab_fp,
123
+ do_lower_case=do_lower_case,
124
+ unk_token=unk_token,
125
+ bos_token=bos_token,
126
+ eos_token=eos_token,
127
+ pad_token=pad_token,
128
+ **kwargs,
129
+ )
130
+
131
+ def add_new_special_tags(self, new_tags: List[str]):
132
+ SPECIAL_TAGS.update(new_tags)
133
+
134
+ def _switch_to_input_mode(self):
135
+ self.src = True
136
+ self.padding_side = "left"
137
+ self.current_spm = self.src_spm
138
+ self.current_encoder = self.encoder
139
+ self.current_encoder_rev = self.encoder_rev
140
+
141
+ def _switch_to_target_mode(self):
142
+ self.src = False
143
+ self.padding_side = "right"
144
+ self.current_spm = self.tgt_spm
145
+ self.current_encoder = self.decoder
146
+ self.current_encoder_rev = self.decoder_rev
147
+
148
+ def _load_spm(self, path: str) -> SentencePieceProcessor:
149
+ return SentencePieceProcessor(model_file=path)
150
+
151
+ def _save_json(self, data, path: str) -> None:
152
+ with open(path, "w", encoding="utf-8") as f:
153
+ json.dump(data, f, indent=2)
154
+
155
+ def _load_json(self, path: str) -> Union[Dict, List]:
156
+ with open(path, "r", encoding="utf-8") as f:
157
+ return json.load(f)
158
+
159
+ def _split_tags(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
160
+ tags = [token for token in tokens if token in SPECIAL_TAGS]
161
+ tokens = [token for token in tokens if token not in SPECIAL_TAGS]
162
+ return tags, tokens
163
+
164
+ def _split_pads(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
165
+ pads = [token for token in tokens if token == self.pad_token]
166
+ tokens = [token for token in tokens if token != self.pad_token]
167
+ return pads, tokens
168
+
169
+ @property
170
+ def src_vocab_size(self) -> int:
171
+ return len(self.encoder)
172
+
173
+ @property
174
+ def tgt_vocab_size(self) -> int:
175
+ return len(self.decoder)
176
+
177
+ def get_src_vocab(self) -> Dict[str, int]:
178
+ return dict(self.encoder, **self.added_tokens_encoder)
179
+
180
+ def get_tgt_vocab(self) -> Dict[str, int]:
181
+ return dict(self.decoder, **self.added_tokens_decoder)
182
+
183
+ # hack override
184
+ def get_vocab(self) -> Dict[str, int]:
185
+ return self.get_src_vocab()
186
+
187
+ # hack override
188
+ @property
189
+ def vocab_size(self) -> int:
190
+ return self.src_vocab_size
191
+
192
+ def _convert_token_to_id(self, token: str) -> int:
193
+ """Converts an token (str) into an index (integer) using the source/target vocabulary map."""
194
+ return self.current_encoder.get(token, self.current_encoder[self.unk_token])
195
+
196
+ def _convert_id_to_token(self, index: int) -> str:
197
+ """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
198
+ return self.current_encoder_rev.get(index, self.unk_token)
199
+
200
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
201
+ """Uses sentencepiece model for detokenization"""
202
+ pads, tokens = self._split_pads(tokens)
203
+
204
+ if self.src:
205
+
206
+ tags, non_tags = self._split_tags(tokens)
207
+
208
+ return (
209
+ " ".join(pads)
210
+ + " "
211
+ + " ".join(tags)
212
+ + " "
213
+ + "".join(non_tags).replace(SPIECE_UNDERLINE, " ").strip()
214
+ )
215
+
216
+ return (
217
+ "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
218
+ + " "
219
+ + " ".join(pads)
220
+ )
221
+
222
+ def _tokenize(self, text) -> List[str]:
223
+ if self.src:
224
+ tokens = text.split(" ")
225
+ tags, non_tags = self._split_tags(tokens)
226
+ text = " ".join(non_tags)
227
+ tokens = self.current_spm.EncodeAsPieces(text)
228
+ return tags + tokens
229
+ else:
230
+ return self.current_spm.EncodeAsPieces(text)
231
+
232
+ def build_inputs_with_special_tokens(
233
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234
+ ) -> List[int]:
235
+ if token_ids_1 is None:
236
+ return token_ids_0 + [self.eos_token_id]
237
+ # We don't expect to process pairs, but leave the pair logic for API consistency
238
+ return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
239
+
240
+ def save_vocabulary(
241
+ self, save_directory: str, filename_prefix: Optional[str] = None
242
+ ) -> Tuple[str]:
243
+ if not os.path.isdir(save_directory):
244
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
245
+ return
246
+
247
+ src_spm_fp = os.path.join(save_directory, "model.SRC")
248
+ tgt_spm_fp = os.path.join(save_directory, "model.TGT")
249
+ src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
250
+ tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")
251
+
252
+ self._save_json(self.encoder, src_vocab_fp)
253
+ self._save_json(self.decoder, tgt_vocab_fp)
254
+
255
+ with open(src_spm_fp, "wb") as f:
256
+ f.write(self.src_spm.serialized_model_proto())
257
+
258
+ with open(tgt_spm_fp, "wb") as f:
259
+ f.write(self.tgt_spm.serialized_model_proto())
260
+
261
+ return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp