Update tokenization_indictrans.py

#7
by VarunGumma - opened
Files changed (1)
  1. tokenization_indictrans.py +133 -140
tokenization_indictrans.py CHANGED
@@ -1,9 +1,10 @@
 import os
 import json

 from typing import Dict, List, Optional, Union, Tuple

-from transformers.utils import logging
 from sentencepiece import SentencePieceProcessor
 from transformers.tokenization_utils import PreTrainedTokenizer

@@ -12,44 +13,45 @@ logger = logging.get_logger(__name__)

 SPIECE_UNDERLINE = "▁"

-SPECIAL_TAGS = {
-    "_bt_",
-    "_ft_",
-    "asm_Beng",
-    "awa_Deva",
-    "ben_Beng",
-    "bho_Deva",
-    "brx_Deva",
-    "doi_Deva",
-    "eng_Latn",
-    "gom_Deva",
-    "gon_Deva",
-    "guj_Gujr",
-    "hin_Deva",
-    "hne_Deva",
-    "kan_Knda",
-    "kas_Arab",
-    "kas_Deva",
-    "kha_Latn",
-    "lus_Latn",
-    "mag_Deva",
-    "mai_Deva",
-    "mal_Mlym",
-    "mar_Deva",
-    "mni_Beng",
-    "mni_Mtei",
-    "npi_Deva",
-    "ory_Orya",
-    "pan_Guru",
-    "san_Deva",
-    "sat_Olck",
-    "snd_Arab",
-    "snd_Deva",
-    "tam_Taml",
-    "tel_Telu",
-    "urd_Arab",
-    "unr_Deva",
-}

 VOCAB_FILES_NAMES = {
     "src_vocab_fp": "dict.SRC.json",
@@ -60,9 +62,8 @@ VOCAB_FILES_NAMES = {


 class IndicTransTokenizer(PreTrainedTokenizer):
-    _added_tokens_encoder = {}
-    _added_tokens_decoder = {}
-
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]

@@ -79,43 +80,51 @@ class IndicTransTokenizer(PreTrainedTokenizer):
         do_lower_case=False,
         **kwargs,
     ):
-
-        self.src = True
-
         self.src_vocab_fp = src_vocab_fp
         self.tgt_vocab_fp = tgt_vocab_fp
         self.src_spm_fp = src_spm_fp
         self.tgt_spm_fp = tgt_spm_fp

-        self.unk_token = unk_token.content
-        self.pad_token = pad_token.content
-        self.eos_token = eos_token.content
-        self.bos_token = bos_token.content

-        self.encoder = self._load_json(self.src_vocab_fp)
-        if self.unk_token not in self.encoder:
             raise KeyError("<unk> token must be in vocab")
-        assert self.pad_token in self.encoder
-        self.encoder_rev = {v: k for k, v in self.encoder.items()}

-        self.decoder = self._load_json(self.tgt_vocab_fp)
-        if self.unk_token not in self.encoder:
-            raise KeyError("<unk> token must be in vocab")
-        assert self.pad_token in self.encoder
-        self.decoder_rev = {v: k for k, v in self.decoder.items()}

-        # load SentencePiece model for pre-processing
         self.src_spm = self._load_spm(self.src_spm_fp)
         self.tgt_spm = self._load_spm(self.tgt_spm_fp)

-        self.current_spm = self.src_spm
-        self.current_encoder = self.encoder
-        self.current_encoder_rev = self.encoder_rev

-        self.unk_token_id = self.encoder[self.unk_token]
-        self.pad_token_id = self.encoder[self.pad_token]
-        self.eos_token_id = self.encoder[self.eos_token]
-        self.bos_token_id = self.encoder[self.bos_token]

         super().__init__(
             src_vocab_file=self.src_vocab_fp,
@@ -128,134 +137,118 @@ class IndicTransTokenizer(PreTrainedTokenizer):
             **kwargs,
         )

-    def add_new_special_tags(self, new_tags: List[str]):
-        SPECIAL_TAGS.update(new_tags)

-    def _switch_to_input_mode(self):
-        self.src = True
         self.padding_side = "left"
-        self.current_spm = self.src_spm
-        self.current_encoder = self.encoder
-        self.current_encoder_rev = self.encoder_rev

-    def _switch_to_target_mode(self):
-        self.src = False
         self.padding_side = "right"
-        self.current_spm = self.tgt_spm
-        self.current_encoder = self.decoder
-        self.current_encoder_rev = self.decoder_rev

-    def _load_spm(self, path: str) -> SentencePieceProcessor:
         return SentencePieceProcessor(model_file=path)

-    def _save_json(self, data, path: str) -> None:
         with open(path, "w", encoding="utf-8") as f:
             json.dump(data, f, indent=2)

-    def _load_json(self, path: str) -> Union[Dict, List]:
         with open(path, "r", encoding="utf-8") as f:
             return json.load(f)

-    def _split_tags(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
-        tags = [token for token in tokens if token in SPECIAL_TAGS]
-        tokens = [token for token in tokens if token not in SPECIAL_TAGS]
-        return tags, tokens
-
-    def _split_pads(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
-        pads = [token for token in tokens if token == self.pad_token]
-        tokens = [token for token in tokens if token != self.pad_token]
-        return pads, tokens
-
     @property
     def src_vocab_size(self) -> int:
-        return len(self.encoder)

     @property
     def tgt_vocab_size(self) -> int:
-        return len(self.decoder)

     def get_src_vocab(self) -> Dict[str, int]:
-        return dict(self.encoder, **self.added_tokens_encoder)

     def get_tgt_vocab(self) -> Dict[str, int]:
-        return dict(self.decoder, **self.added_tokens_decoder)

-    # hack override
     def get_vocab(self) -> Dict[str, int]:
         return self.get_src_vocab()

-    # hack override
     @property
     def vocab_size(self) -> int:
         return self.src_vocab_size

     def _convert_token_to_id(self, token: str) -> int:
-        """Converts an token (str) into an index (integer) using the source/target vocabulary map."""
-        return self.current_encoder.get(token, self.current_encoder[self.unk_token])

     def _convert_id_to_token(self, index: int) -> str:
-        """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
-        return self.current_encoder_rev.get(index, self.unk_token)

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Uses sentencepiece model for detokenization"""
-        pads, tokens = self._split_pads(tokens)
-
-        if self.src:

-            tags, non_tags = self._split_tags(tokens)

-            return (
-                " ".join(pads)
-                + " "
-                + " ".join(tags)
-                + " "
-                + "".join(non_tags).replace(SPIECE_UNDERLINE, " ").strip()
-            )

-        return (
-            "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
-            + " "
-            + " ".join(pads)
         )
-
-    def _tokenize(self, text) -> List[str]:
-        if self.src:
-            tokens = text.split(" ")
-            tags, non_tags = self._split_tags(tokens)
-            text = " ".join(non_tags)
-            tokens = self.current_spm.EncodeAsPieces(text)
-            return tags + tokens
-        else:
-            return self.current_spm.EncodeAsPieces(text)

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
-        if token_ids_1 is None:
-            return token_ids_0 + [self.eos_token_id]
-        # We don't expect to process pairs, but leave the pair logic for API consistency
-        return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

     def save_vocabulary(
         self, save_directory: str, filename_prefix: Optional[str] = None
-    ) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
-            return

         src_spm_fp = os.path.join(save_directory, "model.SRC")
         tgt_spm_fp = os.path.join(save_directory, "model.TGT")
         src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
         tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")

-        self._save_json(self.encoder, src_vocab_fp)
-        self._save_json(self.decoder, tgt_vocab_fp)
-
-        with open(src_spm_fp, "wb") as f:
-            f.write(self.src_spm.serialized_model_proto())

-        with open(tgt_spm_fp, "wb") as f:
-            f.write(self.tgt_spm.serialized_model_proto())

         return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
 
 import os
 import json
+from functools import lru_cache

+from transformers.utils import logging
 from typing import Dict, List, Optional, Union, Tuple

 from sentencepiece import SentencePieceProcessor
 from transformers.tokenization_utils import PreTrainedTokenizer


 SPIECE_UNDERLINE = "▁"

+# Convert SPECIAL_TAGS to a frozen set for faster lookups
+SPECIAL_TAGS = frozenset(
+    {
+        "asm_Beng",
+        "awa_Deva",
+        "ben_Beng",
+        "bho_Deva",
+        "brx_Deva",
+        "doi_Deva",
+        "eng_Latn",
+        "gom_Deva",
+        "gon_Deva",
+        "guj_Gujr",
+        "hin_Deva",
+        "hne_Deva",
+        "kan_Knda",
+        "kas_Arab",
+        "kas_Deva",
+        "kha_Latn",
+        "lus_Latn",
+        "mag_Deva",
+        "mai_Deva",
+        "mal_Mlym",
+        "mar_Deva",
+        "mni_Beng",
+        "mni_Mtei",
+        "npi_Deva",
+        "ory_Orya",
+        "pan_Guru",
+        "san_Deva",
+        "sat_Olck",
+        "snd_Arab",
+        "snd_Deva",
+        "tam_Taml",
+        "tel_Telu",
+        "urd_Arab",
+        "unr_Deva",
+    }
+)

 VOCAB_FILES_NAMES = {
     "src_vocab_fp": "dict.SRC.json",


 class IndicTransTokenizer(PreTrainedTokenizer):
+    _added_tokens_encoder: Dict[str, int] = {}
+    _added_tokens_decoder: Dict[str, int] = {}
     vocab_files_names = VOCAB_FILES_NAMES
     model_input_names = ["input_ids", "attention_mask"]

         do_lower_case=False,
         **kwargs,
     ):
         self.src_vocab_fp = src_vocab_fp
         self.tgt_vocab_fp = tgt_vocab_fp
         self.src_spm_fp = src_spm_fp
         self.tgt_spm_fp = tgt_spm_fp

+        # Store token content directly instead of accessing .content
+        self.unk_token = (
+            hasattr(unk_token, "content") and unk_token.content or unk_token
+        )
+        self.pad_token = (
+            hasattr(pad_token, "content") and pad_token.content or pad_token
+        )
+        self.eos_token = (
+            hasattr(eos_token, "content") and eos_token.content or eos_token
+        )
+        self.bos_token = (
+            hasattr(bos_token, "content") and bos_token.content or bos_token
+        )
+
+        # Load vocabularies
+        self.src_encoder = self._load_json(self.src_vocab_fp)
+        self.tgt_encoder = self._load_json(self.tgt_vocab_fp)

+        # Validate tokens
+        if self.unk_token not in self.src_encoder:
             raise KeyError("<unk> token must be in vocab")
+        if self.pad_token not in self.src_encoder:
+            raise KeyError("<pad> token must be in vocab")

+        # Pre-compute reverse mappings
+        self.src_decoder = {v: k for k, v in self.src_encoder.items()}
+        self.tgt_decoder = {v: k for k, v in self.tgt_encoder.items()}

+        # Load SPM models
         self.src_spm = self._load_spm(self.src_spm_fp)
         self.tgt_spm = self._load_spm(self.tgt_spm_fp)

+        # Initialize current settings
+        self._switch_to_input_mode()

+        # Cache token IDs
+        self.unk_token_id = self.src_encoder[self.unk_token]
+        self.pad_token_id = self.src_encoder[self.pad_token]
+        self.eos_token_id = self.src_encoder[self.eos_token]
+        self.bos_token_id = self.src_encoder[self.bos_token]

         super().__init__(
             src_vocab_file=self.src_vocab_fp,
             **kwargs,
         )

+    def add_new_special_tags(self, new_tags: List[str]) -> None:
+        global SPECIAL_TAGS
+        SPECIAL_TAGS = frozenset(SPECIAL_TAGS | set(new_tags))

+    def _switch_to_input_mode(self) -> None:
+        self.spm = self.src_spm
         self.padding_side = "left"
+        self.encoder = self.src_encoder
+        self.decoder = self.src_decoder
+        self._tokenize = self._src_tokenize

+    def _switch_to_target_mode(self) -> None:
+        self.spm = self.tgt_spm
         self.padding_side = "right"
+        self.encoder = self.tgt_encoder
+        self.decoder = self.tgt_decoder
+        self._tokenize = self._tgt_tokenize

+    @staticmethod
+    def _load_spm(path: str) -> SentencePieceProcessor:
         return SentencePieceProcessor(model_file=path)

+    @staticmethod
+    def _save_json(data: Union[Dict, List], path: str) -> None:
         with open(path, "w", encoding="utf-8") as f:
             json.dump(data, f, indent=2)

+    @staticmethod
+    def _load_json(path: str) -> Union[Dict, List]:
         with open(path, "r", encoding="utf-8") as f:
             return json.load(f)

     @property
     def src_vocab_size(self) -> int:
+        return len(self.src_encoder)

     @property
     def tgt_vocab_size(self) -> int:
+        return len(self.tgt_encoder)

     def get_src_vocab(self) -> Dict[str, int]:
+        return dict(self.src_encoder, **self.added_tokens_encoder)

     def get_tgt_vocab(self) -> Dict[str, int]:
+        return dict(self.tgt_encoder, **self.added_tokens_decoder)

     def get_vocab(self) -> Dict[str, int]:
         return self.get_src_vocab()

     @property
     def vocab_size(self) -> int:
         return self.src_vocab_size

+    @lru_cache(maxsize=10240)
     def _convert_token_to_id(self, token: str) -> int:
+        return self.encoder.get(token, self.unk_token_id)

+    @lru_cache(maxsize=10240)
     def _convert_id_to_token(self, index: int) -> str:
+        return self.decoder.get(index, self.unk_token)

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        return "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()

+    def _src_tokenize(self, text: str) -> List[str]:
+        src_lang, tgt_lang, text = text.split(" ", 2)
+        return [src_lang, tgt_lang] + self.spm.EncodeAsPieces(text)

+    def _tgt_tokenize(self, text: str) -> List[str]:
+        return self.spm.EncodeAsPieces(text)

+    def _decode(
+        self,
+        token_ids: Union[int, List[int]],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = None,
+        spaces_between_special_tokens: bool = True,
+        **kwargs,
+    ) -> str:
+        self._switch_to_target_mode()
+        decoded_token_ids = super()._decode(
+            token_ids=token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+            **kwargs,
         )
+        self._switch_to_input_mode()
+        return decoded_token_ids

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
+        return token_ids_0 + [self.eos_token_id]

     def save_vocabulary(
         self, save_directory: str, filename_prefix: Optional[str] = None
+    ) -> Tuple[str, ...]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return ()

         src_spm_fp = os.path.join(save_directory, "model.SRC")
         tgt_spm_fp = os.path.join(save_directory, "model.TGT")
         src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
         tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")

+        self._save_json(self.src_encoder, src_vocab_fp)
+        self._save_json(self.tgt_encoder, tgt_vocab_fp)

+        for fp, spm in [(src_spm_fp, self.src_spm), (tgt_spm_fp, self.tgt_spm)]:
+            with open(fp, "wb") as f:
+                f.write(spm.serialized_model_proto())

         return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
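
For reference, a minimal usage sketch of the updated tokenizer (not part of this diff). It assumes the file is loaded through AutoTokenizer with trust_remote_code=True from an IndicTrans2-style checkpoint; the checkpoint id and the example sentence below are placeholders, not taken from this PR. Source text is passed as "<src_lang> <tgt_lang> sentence", which _src_tokenize splits into the two language tags plus SentencePiece pieces, and batch_decode goes through the overridden _decode, which temporarily switches to the target-side vocabulary and SentencePiece model.

# Illustrative sketch only; checkpoint id and sentence are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B",  # example checkpoint, not from this PR
    trust_remote_code=True,
)

# Input mode: "<src_lang> <tgt_lang> <sentence>", padded on the left.
batch = tokenizer(
    ["eng_Latn hin_Deva This is a test sentence."],
    padding=True,
    return_tensors="pt",
)

# Output mode: batch_decode() routes through the overridden _decode(),
# which switches to the target vocabulary/SPM and back again.
generated_ids = batch["input_ids"]  # stand-in for model.generate(**batch)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))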