add training and tokenization scripts
Browse files- run-t5v1_1-small.sh +0 -1
- run-t5v1_1-small.sh +26 -0
- run_t5_mlm_flax.py +684 -0
- t5_tokenizer_model.py +140 -0
run-t5v1_1-small.sh
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
../run-t5v1_1-small.sh
|
|
|
|
run-t5v1_1-small.sh
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export model_dir=arabic-t5-small
|
2 |
+
export train_batch_size=48
|
3 |
+
export eval_batch_size=96
|
4 |
+
|
5 |
+
python ./run_t5_mlm_flax.py \
|
6 |
+
--model_type t5 \
|
7 |
+
--config_name ${model_dir} \
|
8 |
+
--tokenizer_name ${model_dir} \
|
9 |
+
--use_fast_tokenizer True \
|
10 |
+
--dtype float32 \
|
11 |
+
--max_seq_length 512 \
|
12 |
+
--preprocessing_num_workers 96 \
|
13 |
+
--output_dir ${model_dir} \
|
14 |
+
--overwrite_output_dir True \
|
15 |
+
--do_train \
|
16 |
+
--per_device_train_batch_size ${train_batch_size} \
|
17 |
+
--per_device_eval_batch_size ${eval_batch_size} \
|
18 |
+
--learning_rate 1e-2 \
|
19 |
+
--num_train_epochs 1 \
|
20 |
+
--logging_steps 100 \
|
21 |
+
--eval_steps 1000 \
|
22 |
+
--save_steps 1000 \
|
23 |
+
--seed 12 \
|
24 |
+
--adafactor True \
|
25 |
+
--push_to_hub \
|
26 |
+
--cache_dir ./training_cache \
|
run_t5_mlm_flax.py
ADDED
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2021 The HuggingFace Team All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Pretraining the library models for T5-like span-masked language modeling on a text file or a dataset.
|
18 |
+
|
19 |
+
Here is the full list of checkpoints on the hub that can be pretrained by this script:
|
20 |
+
https://huggingface.co/models?filter=t5
|
21 |
+
"""
|
22 |
+
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
23 |
+
import logging
|
24 |
+
import os
|
25 |
+
import sys
|
26 |
+
import time
|
27 |
+
from dataclasses import dataclass, field
|
28 |
+
from pathlib import Path
|
29 |
+
from typing import Dict, List, Optional
|
30 |
+
|
31 |
+
import numpy as np
|
32 |
+
from datasets import load_dataset, concatenate_datasets, load_from_disk
|
33 |
+
from tqdm import tqdm
|
34 |
+
|
35 |
+
import flax
|
36 |
+
import jax
|
37 |
+
import jax.numpy as jnp
|
38 |
+
import optax
|
39 |
+
from flax import jax_utils, traverse_util
|
40 |
+
from flax.training import train_state
|
41 |
+
from flax.training.checkpoints import save_checkpoint
|
42 |
+
from flax.training.common_utils import get_metrics, onehot, shard
|
43 |
+
from transformers import (
|
44 |
+
CONFIG_MAPPING,
|
45 |
+
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
46 |
+
BatchEncoding,
|
47 |
+
FlaxT5ForConditionalGeneration,
|
48 |
+
HfArgumentParser,
|
49 |
+
PreTrainedTokenizerBase,
|
50 |
+
T5Config,
|
51 |
+
T5TokenizerFast,
|
52 |
+
TrainingArguments,
|
53 |
+
is_tensorboard_available,
|
54 |
+
set_seed,
|
55 |
+
)
|
56 |
+
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
|
57 |
+
|
58 |
+
|
59 |
+
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
60 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
61 |
+
|
62 |
+
|
63 |
+
@dataclass
|
64 |
+
class ModelArguments:
|
65 |
+
"""
|
66 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
67 |
+
"""
|
68 |
+
|
69 |
+
model_name_or_path: Optional[str] = field(
|
70 |
+
default=None,
|
71 |
+
metadata={
|
72 |
+
"help": "The model checkpoint for weights initialization."
|
73 |
+
"Don't set if you want to train a model from scratch."
|
74 |
+
},
|
75 |
+
)
|
76 |
+
model_type: Optional[str] = field(
|
77 |
+
default=None,
|
78 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
79 |
+
)
|
80 |
+
config_name: Optional[str] = field(
|
81 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
82 |
+
)
|
83 |
+
tokenizer_name: Optional[str] = field(
|
84 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
85 |
+
)
|
86 |
+
cache_dir: Optional[str] = field(
|
87 |
+
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
88 |
+
)
|
89 |
+
use_fast_tokenizer: bool = field(
|
90 |
+
default=True,
|
91 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
92 |
+
)
|
93 |
+
dtype: Optional[str] = field(
|
94 |
+
default="float32",
|
95 |
+
metadata={
|
96 |
+
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
97 |
+
},
|
98 |
+
)
|
99 |
+
|
100 |
+
|
101 |
+
@dataclass
|
102 |
+
class DataTrainingArguments:
|
103 |
+
"""
|
104 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
105 |
+
"""
|
106 |
+
|
107 |
+
overwrite_cache: bool = field(
|
108 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
109 |
+
)
|
110 |
+
max_seq_length: Optional[int] = field(
|
111 |
+
default=None,
|
112 |
+
metadata={
|
113 |
+
"help": "The maximum total input sequence length after tokenization and masking. Sequences longer than this will be truncated. Default to the max input length of the model."
|
114 |
+
},
|
115 |
+
)
|
116 |
+
preprocessing_num_workers: Optional[int] = field(
|
117 |
+
default=None,
|
118 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
119 |
+
)
|
120 |
+
mlm_probability: float = field(
|
121 |
+
default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
|
122 |
+
)
|
123 |
+
mean_noise_span_length: float = field(
|
124 |
+
default=3.0,
|
125 |
+
metadata={"help": "Mean span length of masked tokens"},
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
|
130 |
+
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
|
131 |
+
|
132 |
+
Training parameters to avoid padding with random_spans_noise_mask.
|
133 |
+
When training a model with random_spans_noise_mask, we would like to set the other
|
134 |
+
training hyperparmeters in a way that avoids padding.
|
135 |
+
This function helps us compute these hyperparameters.
|
136 |
+
We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
|
137 |
+
and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
|
138 |
+
This function tells us the required number of tokens in the raw example (for split_tokens())
|
139 |
+
as well as the length of the encoded targets. Note that this function assumes
|
140 |
+
the inputs and targets will have EOS appended and includes that in the reported length.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
inputs_length: an integer - desired length of the tokenized inputs sequence
|
144 |
+
noise_density: a float
|
145 |
+
mean_noise_span_length: a float
|
146 |
+
Returns:
|
147 |
+
tokens_length: length of original text in tokens
|
148 |
+
targets_length: an integer - length in tokens of encoded targets sequence
|
149 |
+
"""
|
150 |
+
|
151 |
+
def _tokens_length_to_inputs_length_targets_length(tokens_length):
|
152 |
+
num_noise_tokens = int(round(tokens_length * noise_density))
|
153 |
+
num_nonnoise_tokens = tokens_length - num_noise_tokens
|
154 |
+
num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
|
155 |
+
# inputs contain all nonnoise tokens, sentinels for all noise spans
|
156 |
+
# and one EOS token.
|
157 |
+
_input_length = num_nonnoise_tokens + num_noise_spans + 1
|
158 |
+
_output_length = num_noise_tokens + num_noise_spans + 1
|
159 |
+
return _input_length, _output_length
|
160 |
+
|
161 |
+
tokens_length = inputs_length
|
162 |
+
|
163 |
+
while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
|
164 |
+
tokens_length += 1
|
165 |
+
|
166 |
+
inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
|
167 |
+
|
168 |
+
# minor hack to get the targets length to be equal to inputs length
|
169 |
+
# which is more likely to have been set to a nice round number.
|
170 |
+
if noise_density == 0.5 and targets_length > inputs_length:
|
171 |
+
tokens_length -= 1
|
172 |
+
targets_length -= 1
|
173 |
+
return tokens_length, targets_length
|
174 |
+
|
175 |
+
|
176 |
+
@flax.struct.dataclass
|
177 |
+
class FlaxDataCollatorForT5MLM:
|
178 |
+
"""
|
179 |
+
Data collator used for T5 span-masked language modeling.
|
180 |
+
It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
|
181 |
+
For more information on how T5 span-masked language modeling works, one can take a look
|
182 |
+
at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
|
183 |
+
or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
|
184 |
+
|
185 |
+
Args:
|
186 |
+
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
187 |
+
The tokenizer used for encoding the data.
|
188 |
+
noise_density (:obj:`float`):
|
189 |
+
The probability with which to (randomly) mask tokens in the input.
|
190 |
+
mean_noise_span_length (:obj:`float`):
|
191 |
+
The average span length of the masked tokens.
|
192 |
+
input_length (:obj:`int`):
|
193 |
+
The expected input length after masking.
|
194 |
+
target_length (:obj:`int`):
|
195 |
+
The expected target length after masking.
|
196 |
+
pad_token_id: (:obj:`int`):
|
197 |
+
The pad token id of the model
|
198 |
+
decoder_start_token_id: (:obj:`int):
|
199 |
+
The decoder start token id of the model
|
200 |
+
"""
|
201 |
+
|
202 |
+
tokenizer: PreTrainedTokenizerBase
|
203 |
+
noise_density: float
|
204 |
+
mean_noise_span_length: float
|
205 |
+
input_length: int
|
206 |
+
target_length: int
|
207 |
+
pad_token_id: int
|
208 |
+
decoder_start_token_id: int
|
209 |
+
|
210 |
+
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
|
211 |
+
|
212 |
+
# convert list to dict and tensorize input
|
213 |
+
batch = BatchEncoding(
|
214 |
+
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
|
215 |
+
)
|
216 |
+
|
217 |
+
input_ids = batch["input_ids"]
|
218 |
+
batch_size, expandend_input_length = input_ids.shape
|
219 |
+
|
220 |
+
mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
|
221 |
+
labels_mask = ~mask_indices
|
222 |
+
|
223 |
+
input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
|
224 |
+
labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
|
225 |
+
|
226 |
+
batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
|
227 |
+
batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
|
228 |
+
|
229 |
+
if batch["input_ids"].shape[-1] != self.input_length:
|
230 |
+
raise ValueError(
|
231 |
+
f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but should be {self.target_length}."
|
232 |
+
)
|
233 |
+
|
234 |
+
if batch["labels"].shape[-1] != self.target_length:
|
235 |
+
raise ValueError(
|
236 |
+
f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be {self.target_length}."
|
237 |
+
)
|
238 |
+
|
239 |
+
# to check that tokens are correctly proprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
|
240 |
+
batch["decoder_input_ids"] = shift_tokens_right(
|
241 |
+
batch["labels"], self.pad_token_id, self.decoder_start_token_id
|
242 |
+
)
|
243 |
+
|
244 |
+
return batch
|
245 |
+
|
246 |
+
def create_sentinel_ids(self, mask_indices):
|
247 |
+
"""
|
248 |
+
Sentinel ids creation given the indices that should be masked.
|
249 |
+
The start indices of each mask are replaced by the sentinel ids in increasing
|
250 |
+
order. Consecutive mask indices to be deleted are replaced with `-1`.
|
251 |
+
"""
|
252 |
+
start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
|
253 |
+
start_indices[:, 0] = mask_indices[:, 0]
|
254 |
+
|
255 |
+
sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
|
256 |
+
sentinel_ids = np.where(sentinel_ids != 0, (sentinel_ids + self.tokenizer.vocab_size - 1), 0)
|
257 |
+
sentinel_ids -= mask_indices - start_indices
|
258 |
+
|
259 |
+
return sentinel_ids
|
260 |
+
|
261 |
+
def filter_input_ids(self, input_ids, sentinel_ids):
|
262 |
+
"""
|
263 |
+
Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
|
264 |
+
This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
|
265 |
+
"""
|
266 |
+
batch_size = input_ids.shape[0]
|
267 |
+
|
268 |
+
input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
|
269 |
+
input_ids = input_ids_full[input_ids_full > 0].reshape((batch_size, -1))
|
270 |
+
input_ids = np.concatenate(
|
271 |
+
[input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
|
272 |
+
)
|
273 |
+
return input_ids
|
274 |
+
|
275 |
+
def random_spans_noise_mask(self, length):
|
276 |
+
|
277 |
+
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
|
278 |
+
|
279 |
+
Noise mask consisting of random spans of noise tokens.
|
280 |
+
The number of noise tokens and the number of noise spans and non-noise spans
|
281 |
+
are determined deterministically as follows:
|
282 |
+
num_noise_tokens = round(length * noise_density)
|
283 |
+
num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
|
284 |
+
Spans alternate between non-noise and noise, beginning with non-noise.
|
285 |
+
Subject to the above restrictions, all masks are equally likely.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
length: an int32 scalar (length of the incoming token sequence)
|
289 |
+
noise_density: a float - approximate density of output mask
|
290 |
+
mean_noise_span_length: a number
|
291 |
+
|
292 |
+
Returns:
|
293 |
+
a boolean tensor with shape [length]
|
294 |
+
"""
|
295 |
+
|
296 |
+
orig_length = length
|
297 |
+
|
298 |
+
num_noise_tokens = int(np.round(length * self.noise_density))
|
299 |
+
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
|
300 |
+
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
|
301 |
+
num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
|
302 |
+
|
303 |
+
# avoid degeneracy by ensuring positive number of noise spans
|
304 |
+
num_noise_spans = max(num_noise_spans, 1)
|
305 |
+
num_nonnoise_tokens = length - num_noise_tokens
|
306 |
+
|
307 |
+
# pick the lengths of the noise spans and the non-noise spans
|
308 |
+
def _random_segmentation(num_items, num_segments):
|
309 |
+
"""Partition a sequence of items randomly into non-empty segments.
|
310 |
+
Args:
|
311 |
+
num_items: an integer scalar > 0
|
312 |
+
num_segments: an integer scalar in [1, num_items]
|
313 |
+
Returns:
|
314 |
+
a Tensor with shape [num_segments] containing positive integers that add
|
315 |
+
up to num_items
|
316 |
+
"""
|
317 |
+
mask_indices = np.arange(num_items - 1) < (num_segments - 1)
|
318 |
+
np.random.shuffle(mask_indices)
|
319 |
+
first_in_segment = np.pad(mask_indices, [[1, 0]])
|
320 |
+
segment_id = np.cumsum(first_in_segment)
|
321 |
+
segment_length = np.asarray(jax.ops.segment_sum(np.ones_like(segment_id), segment_id))
|
322 |
+
return segment_length
|
323 |
+
|
324 |
+
noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
|
325 |
+
nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
|
326 |
+
|
327 |
+
interleaved_span_lengths = np.reshape(
|
328 |
+
np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
|
329 |
+
)
|
330 |
+
span_starts = np.cumsum(interleaved_span_lengths)[:-1]
|
331 |
+
span_start_indicator = np.zeros((length,), dtype=np.int8)
|
332 |
+
span_start_indicator[span_starts] = True
|
333 |
+
span_num = np.cumsum(span_start_indicator)
|
334 |
+
is_noise = np.equal(span_num % 2, 1)
|
335 |
+
|
336 |
+
return is_noise[:orig_length]
|
337 |
+
|
338 |
+
|
339 |
+
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
340 |
+
num_samples = len(samples_idx)
|
341 |
+
samples_to_remove = num_samples % batch_size
|
342 |
+
|
343 |
+
if samples_to_remove != 0:
|
344 |
+
samples_idx = samples_idx[:-samples_to_remove]
|
345 |
+
sections_split = num_samples // batch_size
|
346 |
+
batch_idx = np.split(samples_idx, sections_split)
|
347 |
+
return batch_idx
|
348 |
+
|
349 |
+
|
350 |
+
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
351 |
+
summary_writer.scalar("train_time", train_time, step)
|
352 |
+
|
353 |
+
train_metrics = get_metrics(train_metrics)
|
354 |
+
for key, vals in train_metrics.items():
|
355 |
+
tag = f"train_{key}"
|
356 |
+
for i, val in enumerate(vals):
|
357 |
+
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
358 |
+
|
359 |
+
|
360 |
+
def write_eval_metric(summary_writer, eval_metrics, step):
|
361 |
+
for metric_name, value in eval_metrics.items():
|
362 |
+
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
363 |
+
|
364 |
+
|
365 |
+
if __name__ == "__main__":
|
366 |
+
# See all possible arguments in src/transformers/training_args.py
|
367 |
+
# or by passing the --help flag to this script.
|
368 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
369 |
+
|
370 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
371 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
372 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
373 |
+
# let's parse it to get our arguments.
|
374 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
375 |
+
else:
|
376 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
377 |
+
|
378 |
+
if (
|
379 |
+
os.path.exists(training_args.output_dir)
|
380 |
+
and os.listdir(training_args.output_dir)
|
381 |
+
and training_args.do_train
|
382 |
+
and not training_args.overwrite_output_dir
|
383 |
+
):
|
384 |
+
raise ValueError(
|
385 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
386 |
+
"Use --overwrite_output_dir to overcome."
|
387 |
+
)
|
388 |
+
|
389 |
+
# Setup logging
|
390 |
+
logging.basicConfig(
|
391 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
392 |
+
level="NOTSET",
|
393 |
+
datefmt="[%X]",
|
394 |
+
)
|
395 |
+
|
396 |
+
# Log on each process the small summary:
|
397 |
+
logger = logging.getLogger(__name__)
|
398 |
+
|
399 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
400 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
401 |
+
|
402 |
+
# Set seed before initializing model.
|
403 |
+
set_seed(training_args.seed)
|
404 |
+
|
405 |
+
# Load pretrained model and tokenizer
|
406 |
+
|
407 |
+
if model_args.tokenizer_name:
|
408 |
+
tokenizer = T5TokenizerFast.from_pretrained(
|
409 |
+
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
410 |
+
)
|
411 |
+
elif model_args.model_name_or_path:
|
412 |
+
tokenizer = T5TokenizerFast.from_pretrained(
|
413 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
414 |
+
)
|
415 |
+
else:
|
416 |
+
raise ValueError(
|
417 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
418 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
419 |
+
)
|
420 |
+
|
421 |
+
if model_args.config_name:
|
422 |
+
config = T5Config.from_pretrained(
|
423 |
+
model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
|
424 |
+
)
|
425 |
+
elif model_args.model_name_or_path:
|
426 |
+
config = T5Config.from_pretrained(
|
427 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
|
428 |
+
)
|
429 |
+
else:
|
430 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
431 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
432 |
+
|
433 |
+
|
434 |
+
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
435 |
+
|
436 |
+
# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
|
437 |
+
# To ensure that the input length is `max_seq_length`, we need to increase the maximum length
|
438 |
+
# according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
|
439 |
+
expanded_inputs_length, targets_length = compute_input_and_target_lengths(
|
440 |
+
inputs_length=max_seq_length,
|
441 |
+
noise_density=data_args.mlm_probability,
|
442 |
+
mean_noise_span_length=data_args.mean_noise_span_length,
|
443 |
+
)
|
444 |
+
|
445 |
+
# load the tokenized and grouped dataset
|
446 |
+
tokenized_datasets = load_from_disk("./training_cache")
|
447 |
+
|
448 |
+
# Enable tensorboard only on the master node
|
449 |
+
has_tensorboard = is_tensorboard_available()
|
450 |
+
if has_tensorboard and jax.process_index() == 0:
|
451 |
+
try:
|
452 |
+
from flax.metrics.tensorboard import SummaryWriter
|
453 |
+
|
454 |
+
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
455 |
+
except ImportError as ie:
|
456 |
+
has_tensorboard = False
|
457 |
+
logger.warning(
|
458 |
+
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
459 |
+
)
|
460 |
+
else:
|
461 |
+
logger.warning(
|
462 |
+
"Unable to display metrics through TensorBoard because the package is not installed: "
|
463 |
+
"Please run pip install tensorboard to enable."
|
464 |
+
)
|
465 |
+
|
466 |
+
# Initialize our training
|
467 |
+
rng = jax.random.PRNGKey(training_args.seed)
|
468 |
+
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
469 |
+
|
470 |
+
logger.info(
|
471 |
+
f"JAX devices:\n{jax.devices()}\nNum devices: {jax.device_count()}\nBackend: {jax.lib.xla_bridge.get_backend().platform}"
|
472 |
+
)
|
473 |
+
|
474 |
+
logger.info(
|
475 |
+
"\n==================================================Initializing the model==================================================\n"
|
476 |
+
)
|
477 |
+
|
478 |
+
if model_args.model_name_or_path:
|
479 |
+
model = FlaxT5ForConditionalGeneration.from_pretrained(
|
480 |
+
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
481 |
+
)
|
482 |
+
else:
|
483 |
+
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
484 |
+
|
485 |
+
logger.info(
|
486 |
+
"\n==================================================Done!==================================================\n"
|
487 |
+
)
|
488 |
+
|
489 |
+
# Data collator
|
490 |
+
# This one will take care of randomly masking the tokens.
|
491 |
+
data_collator = FlaxDataCollatorForT5MLM(
|
492 |
+
tokenizer=tokenizer,
|
493 |
+
noise_density=data_args.mlm_probability,
|
494 |
+
mean_noise_span_length=data_args.mean_noise_span_length,
|
495 |
+
input_length=max_seq_length,
|
496 |
+
target_length=targets_length,
|
497 |
+
pad_token_id=model.config.pad_token_id,
|
498 |
+
decoder_start_token_id=model.config.decoder_start_token_id,
|
499 |
+
)
|
500 |
+
|
501 |
+
# Store some constant
|
502 |
+
num_epochs = int(training_args.num_train_epochs)
|
503 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
504 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
505 |
+
|
506 |
+
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
507 |
+
|
508 |
+
# Create learning rate schedule
|
509 |
+
warmup_steps = num_train_steps * 5 // 100
|
510 |
+
warmup_fn = optax.linear_schedule(
|
511 |
+
init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
|
512 |
+
)
|
513 |
+
decay_fn = optax.linear_schedule(
|
514 |
+
init_value=training_args.learning_rate,
|
515 |
+
end_value=0,
|
516 |
+
transition_steps=num_train_steps - warmup_steps,
|
517 |
+
)
|
518 |
+
linear_decay_lr_schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
|
519 |
+
|
520 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
521 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
522 |
+
# mask boolean with the same structure as the parameters.
|
523 |
+
# The mask is True for parameters that should be decayed.
|
524 |
+
def decay_mask_fn(params):
|
525 |
+
flat_params = traverse_util.flatten_dict(params)
|
526 |
+
flat_mask = {
|
527 |
+
path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
|
528 |
+
for path in flat_params
|
529 |
+
}
|
530 |
+
return traverse_util.unflatten_dict(flat_mask)
|
531 |
+
|
532 |
+
# create adam optimizer
|
533 |
+
if training_args.adafactor:
|
534 |
+
# We use the default parameters here to initialize adafactor,
|
535 |
+
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
536 |
+
optimizer = optax.adafactor(
|
537 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
538 |
+
)
|
539 |
+
else:
|
540 |
+
optimizer = optax.adamw(
|
541 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
542 |
+
b1=training_args.adam_beta1,
|
543 |
+
b2=training_args.adam_beta2,
|
544 |
+
weight_decay=training_args.weight_decay,
|
545 |
+
mask=decay_mask_fn,
|
546 |
+
)
|
547 |
+
|
548 |
+
# Setup train state
|
549 |
+
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
|
550 |
+
|
551 |
+
# Define gradient update step fn
|
552 |
+
def train_step(state, batch, dropout_rng):
|
553 |
+
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
554 |
+
|
555 |
+
def loss_fn(params):
|
556 |
+
labels = batch.pop("labels")
|
557 |
+
|
558 |
+
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
559 |
+
|
560 |
+
# compute loss
|
561 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
|
562 |
+
|
563 |
+
return loss
|
564 |
+
|
565 |
+
grad_fn = jax.value_and_grad(loss_fn)
|
566 |
+
loss, grad = grad_fn(state.params)
|
567 |
+
grad = jax.lax.pmean(grad, "batch")
|
568 |
+
new_state = state.apply_gradients(grads=grad)
|
569 |
+
|
570 |
+
metrics = jax.lax.pmean(
|
571 |
+
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
572 |
+
)
|
573 |
+
|
574 |
+
return new_state, metrics, new_dropout_rng
|
575 |
+
|
576 |
+
# Create parallel version of the train step
|
577 |
+
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
578 |
+
|
579 |
+
# Define eval fn
|
580 |
+
def eval_step(params, batch):
|
581 |
+
labels = batch.pop("labels")
|
582 |
+
|
583 |
+
logits = model(**batch, params=params, train=False)[0]
|
584 |
+
|
585 |
+
# compute loss
|
586 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
|
587 |
+
|
588 |
+
# compute accuracy
|
589 |
+
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)
|
590 |
+
|
591 |
+
# summarize metrics
|
592 |
+
metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
|
593 |
+
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
594 |
+
|
595 |
+
return metrics
|
596 |
+
|
597 |
+
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
598 |
+
|
599 |
+
# Replicate the train state on each device
|
600 |
+
state = jax_utils.replicate(state)
|
601 |
+
|
602 |
+
train_time = 0
|
603 |
+
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
604 |
+
for epoch in epochs:
|
605 |
+
# ======================== Training ================================
|
606 |
+
train_start = time.time()
|
607 |
+
train_metrics = []
|
608 |
+
|
609 |
+
# Create sampling rng
|
610 |
+
rng, input_rng = jax.random.split(rng)
|
611 |
+
|
612 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
613 |
+
num_train_samples = len(tokenized_datasets["train"])
|
614 |
+
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
615 |
+
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
616 |
+
|
617 |
+
# Gather the indexes for creating the batch and do a training step
|
618 |
+
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
619 |
+
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
620 |
+
model_inputs = data_collator(samples)
|
621 |
+
|
622 |
+
# Model forward
|
623 |
+
model_inputs = shard(model_inputs.data)
|
624 |
+
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
625 |
+
train_metrics.append(train_metric)
|
626 |
+
|
627 |
+
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
628 |
+
|
629 |
+
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
630 |
+
# Save metrics
|
631 |
+
train_metric = jax_utils.unreplicate(train_metric)
|
632 |
+
train_time += time.time() - train_start
|
633 |
+
if has_tensorboard and jax.process_index() == 0:
|
634 |
+
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
635 |
+
|
636 |
+
epochs.write(
|
637 |
+
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
638 |
+
)
|
639 |
+
|
640 |
+
train_metrics = []
|
641 |
+
|
642 |
+
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
643 |
+
# ======================== Evaluating ==============================
|
644 |
+
num_eval_samples = len(tokenized_datasets["validation"])
|
645 |
+
eval_samples_idx = jnp.arange(num_eval_samples)
|
646 |
+
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
647 |
+
|
648 |
+
eval_metrics = []
|
649 |
+
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
650 |
+
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
651 |
+
model_inputs = data_collator(samples)
|
652 |
+
|
653 |
+
# Model forward
|
654 |
+
model_inputs = shard(model_inputs.data)
|
655 |
+
metrics = p_eval_step(state.params, model_inputs)
|
656 |
+
eval_metrics.append(metrics)
|
657 |
+
|
658 |
+
# get eval metrics
|
659 |
+
eval_metrics = get_metrics(eval_metrics)
|
660 |
+
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
661 |
+
|
662 |
+
# Update progress bar
|
663 |
+
epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
|
664 |
+
|
665 |
+
# Save metrics
|
666 |
+
if has_tensorboard and jax.process_index() == 0:
|
667 |
+
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
668 |
+
|
669 |
+
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
670 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
671 |
+
if jax.process_index() == 0:
|
672 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
673 |
+
save_checkpoint(
|
674 |
+
ckpt_dir=training_args.output_dir,
|
675 |
+
target=jax_utils.unreplicate(state),
|
676 |
+
step=cur_step,
|
677 |
+
overwrite=True,
|
678 |
+
)
|
679 |
+
model.save_pretrained(
|
680 |
+
training_args.output_dir,
|
681 |
+
params=params,
|
682 |
+
push_to_hub=training_args.push_to_hub,
|
683 |
+
commit_message=f"Saving weights and logs of step {cur_step}",
|
684 |
+
)
|
t5_tokenizer_model.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import json
|
3 |
+
from typing import Iterator, List, Union
|
4 |
+
|
5 |
+
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, trainers
|
6 |
+
from tokenizers.implementations.base_tokenizer import BaseTokenizer
|
7 |
+
from tokenizers.models import Unigram
|
8 |
+
from tokenizers.processors import TemplateProcessing
|
9 |
+
|
10 |
+
|
11 |
+
class SentencePieceUnigramTokenizer(BaseTokenizer):
|
12 |
+
"""
|
13 |
+
This class is a copy of `DeDLOC's tokenizer implementation <https://github.com/yandex-research/DeDLOC/blob/main/sahajbert/tokenizer/tokenizer_model.py>`__ .
|
14 |
+
|
15 |
+
Custom SentencePiece Unigram Tokenizer with NMT, NKFC, spaces and lower-casing characters normalization
|
16 |
+
Represents the Unigram algorithm, with the pretokenization used by SentencePiece
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
replacement: str = "▁",
|
22 |
+
add_prefix_space: bool = True,
|
23 |
+
unk_token: Union[str, AddedToken] = "<unk>",
|
24 |
+
eos_token: Union[str, AddedToken] = "</s>",
|
25 |
+
pad_token: Union[str, AddedToken] = "<pad>",
|
26 |
+
):
|
27 |
+
self.special_tokens = {
|
28 |
+
"pad": {"id": 0, "token": pad_token},
|
29 |
+
"eos": {"id": 1, "token": eos_token},
|
30 |
+
"unk": {"id": 2, "token": unk_token},
|
31 |
+
}
|
32 |
+
|
33 |
+
self.special_tokens_list = [None] * len(self.special_tokens)
|
34 |
+
for token_dict in self.special_tokens.values():
|
35 |
+
self.special_tokens_list[token_dict["id"]] = token_dict["token"]
|
36 |
+
|
37 |
+
tokenizer = Tokenizer(Unigram())
|
38 |
+
|
39 |
+
# the following regexes are taken directly from https://github.com/aub-mind/arabert/blob/f92f06a29804f74878e2d1e39ea57fba8dcb0eac/preprocess.py
|
40 |
+
url = " [رابط] "
|
41 |
+
email = " [بريد] "
|
42 |
+
usr = " [مستخدم] "
|
43 |
+
|
44 |
+
url_regexes = [
|
45 |
+
r"(http(s)?:\/\/.)?(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)",
|
46 |
+
r"@(https?|ftp)://(-\.)?([^\s/?\.#-]+\.?)+(/[^\s]*)?$@iS",
|
47 |
+
r"http[s]?://[a-zA-Z0-9_\-./~\?=%&]+",
|
48 |
+
r"www[a-zA-Z0-9_\-?=%&/.~]+",
|
49 |
+
r"[a-zA-Z]+\.com",
|
50 |
+
r"(?=http)[^\s]+",
|
51 |
+
r"(?=www)[^\s]+",
|
52 |
+
r"://",
|
53 |
+
]
|
54 |
+
|
55 |
+
email_regexes = [r"[\w-]+@([\w-]+\.)+[\w-]+", r"\S+@\S+"]
|
56 |
+
|
57 |
+
user_mention_regex = r"@[\w\d]+"
|
58 |
+
|
59 |
+
tokenizer.normalizer = normalizers.Sequence(
|
60 |
+
[
|
61 |
+
normalizers.Nmt(),
|
62 |
+
normalizers.NFKC(),
|
63 |
+
# remove links, emails, user mentions ans hashtags
|
64 |
+
*[normalizers.Replace(Regex(r), url) for r in url_regexes],
|
65 |
+
*[normalizers.Replace(Regex(r), email) for r in email_regexes],
|
66 |
+
normalizers.Replace(Regex(user_mention_regex), usr),
|
67 |
+
# remove html
|
68 |
+
normalizers.Replace(Regex("<br />"), " "),
|
69 |
+
normalizers.Replace(Regex("</?[^>]+>"), " "),
|
70 |
+
# remove extra white space
|
71 |
+
normalizers.Replace(Regex(" {2,}"), " "),
|
72 |
+
normalizers.Lowercase(),
|
73 |
+
]
|
74 |
+
)
|
75 |
+
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
76 |
+
[
|
77 |
+
pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
|
78 |
+
pre_tokenizers.Digits(individual_digits=True),
|
79 |
+
pre_tokenizers.Punctuation(),
|
80 |
+
]
|
81 |
+
)
|
82 |
+
tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
|
83 |
+
|
84 |
+
tokenizer.post_processor = TemplateProcessing(
|
85 |
+
single=f"$A {self.special_tokens['eos']['token']}",
|
86 |
+
special_tokens=[(self.special_tokens["eos"]["token"], self.special_tokens["eos"]["id"])],
|
87 |
+
)
|
88 |
+
|
89 |
+
parameters = {
|
90 |
+
"model": "SentencePieceUnigram",
|
91 |
+
"replacement": replacement,
|
92 |
+
"add_prefix_space": add_prefix_space,
|
93 |
+
}
|
94 |
+
|
95 |
+
super().__init__(tokenizer, parameters)
|
96 |
+
|
97 |
+
def train(
|
98 |
+
self,
|
99 |
+
files: Union[str, List[str]],
|
100 |
+
vocab_size: int = 8000,
|
101 |
+
show_progress: bool = True,
|
102 |
+
):
|
103 |
+
"""Train the model using the given files"""
|
104 |
+
|
105 |
+
trainer = trainers.UnigramTrainer(
|
106 |
+
vocab_size=vocab_size,
|
107 |
+
special_tokens=self.special_tokens_list,
|
108 |
+
show_progress=show_progress,
|
109 |
+
)
|
110 |
+
|
111 |
+
if isinstance(files, str):
|
112 |
+
files = [files]
|
113 |
+
self._tokenizer.train(files, trainer=trainer)
|
114 |
+
|
115 |
+
self.add_unk_id()
|
116 |
+
|
117 |
+
def train_from_iterator(
|
118 |
+
self,
|
119 |
+
iterator: Union[Iterator[str], Iterator[Iterator[str]]],
|
120 |
+
vocab_size: int = 8000,
|
121 |
+
show_progress: bool = True,
|
122 |
+
):
|
123 |
+
"""Train the model using the given iterator"""
|
124 |
+
|
125 |
+
trainer = trainers.UnigramTrainer(
|
126 |
+
vocab_size=vocab_size,
|
127 |
+
special_tokens=self.special_tokens_list,
|
128 |
+
show_progress=show_progress,
|
129 |
+
)
|
130 |
+
|
131 |
+
self._tokenizer.train_from_iterator(iterator, trainer=trainer)
|
132 |
+
|
133 |
+
self.add_unk_id()
|
134 |
+
|
135 |
+
def add_unk_id(self):
|
136 |
+
tokenizer_json = json.loads(self._tokenizer.to_str())
|
137 |
+
|
138 |
+
tokenizer_json["model"]["unk_id"] = self.special_tokens["unk"]["id"]
|
139 |
+
|
140 |
+
self._tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))
|