wilbin commited on
Commit
5a8a304
·
verified ·
1 Parent(s): 147c12e
Files changed (1) hide show
  1. training +1012 -0
training ADDED
@@ -0,0 +1,1012 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pretraining the library models for T5-like span-masked language modeling on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be pretrained by this script:
20
+ https://huggingface.co/models?filter=t5
21
+ """
22
+ import json
23
+ import logging
24
+ import math
25
+ import os
26
+ import sys
27
+ import time
28
+ import warnings
29
+ from dataclasses import asdict, dataclass, field
30
+
31
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
32
+ from enum import Enum
33
+ from itertools import chain
34
+ from pathlib import Path
35
+ from typing import Dict, List, Optional
36
+
37
+ import flax
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import numpy as np
41
+ import optax
42
+ from datasets import load_dataset
43
+ from flax import jax_utils, traverse_util
44
+ from flax.jax_utils import pad_shard_unpad
45
+ from flax.training import train_state
46
+ from flax.training.common_utils import get_metrics, onehot, shard
47
+ from huggingface_hub import Repository, create_repo
48
+ from tqdm import tqdm
49
+
50
+ from transformers import (
51
+ CONFIG_MAPPING,
52
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
53
+ AutoTokenizer,
54
+ BatchEncoding,
55
+ FlaxT5ForConditionalGeneration,
56
+ HfArgumentParser,
57
+ PreTrainedTokenizerBase,
58
+ T5Config,
59
+ is_tensorboard_available,
60
+ set_seed,
61
+ )
62
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
63
+ from transformers.utils import send_example_telemetry
64
+
65
+
66
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
67
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
68
+
69
+
70
+ @dataclass
71
+ class TrainingArguments:
72
+ output_dir: str = field(
73
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
74
+ )
75
+ overwrite_output_dir: bool = field(
76
+ default=False,
77
+ metadata={
78
+ "help": (
79
+ "Overwrite the content of the output directory. "
80
+ "Use this to continue training if output_dir points to a checkpoint directory."
81
+ )
82
+ },
83
+ )
84
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
85
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
86
+ per_device_train_batch_size: int = field(
87
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
88
+ )
89
+ per_device_eval_batch_size: int = field(
90
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
91
+ )
92
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
93
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
94
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
95
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
96
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
97
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
98
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
99
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
100
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
101
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
102
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
103
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
104
+ push_to_hub: bool = field(
105
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
106
+ )
107
+ hub_model_id: str = field(
108
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
109
+ )
110
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
111
+
112
+ def __post_init__(self):
113
+ if self.output_dir is not None:
114
+ self.output_dir = os.path.expanduser(self.output_dir)
115
+
116
+ def to_dict(self):
117
+ """
118
+ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
119
+ the token values by removing their value.
120
+ """
121
+ d = asdict(self)
122
+ for k, v in d.items():
123
+ if isinstance(v, Enum):
124
+ d[k] = v.value
125
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
126
+ d[k] = [x.value for x in v]
127
+ if k.endswith("_token"):
128
+ d[k] = f"<{k.upper()}>"
129
+ return d
130
+
131
+
132
+ @dataclass
133
+ class ModelArguments:
134
+ """
135
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
136
+ """
137
+
138
+ model_name_or_path: Optional[str] = field(
139
+ default=None,
140
+ metadata={
141
+ "help": (
142
+ "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
143
+ )
144
+ },
145
+ )
146
+ model_type: Optional[str] = field(
147
+ default=None,
148
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
149
+ )
150
+ config_name: Optional[str] = field(
151
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
152
+ )
153
+ tokenizer_name: Optional[str] = field(
154
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
155
+ )
156
+ cache_dir: Optional[str] = field(
157
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
158
+ )
159
+ use_fast_tokenizer: bool = field(
160
+ default=True,
161
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
162
+ )
163
+ dtype: Optional[str] = field(
164
+ default="float32",
165
+ metadata={
166
+ "help": (
167
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
168
+ " `[float32, float16, bfloat16]`."
169
+ )
170
+ },
171
+ )
172
+ token: str = field(
173
+ default=None,
174
+ metadata={
175
+ "help": (
176
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
177
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
178
+ )
179
+ },
180
+ )
181
+ use_auth_token: bool = field(
182
+ default=None,
183
+ metadata={
184
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
185
+ },
186
+ )
187
+
188
+
189
+ @dataclass
190
+ class DataTrainingArguments:
191
+ """
192
+ Arguments pertaining to what data we are going to input our model for training and eval.
193
+ """
194
+
195
+ dataset_name: Optional[str] = field(
196
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
197
+ )
198
+ dataset_config_name: Optional[str] = field(
199
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
200
+ )
201
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
202
+ validation_file: Optional[str] = field(
203
+ default=None,
204
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
205
+ )
206
+ train_ref_file: Optional[str] = field(
207
+ default=None,
208
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
209
+ )
210
+ validation_ref_file: Optional[str] = field(
211
+ default=None,
212
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
213
+ )
214
+ overwrite_cache: bool = field(
215
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
216
+ )
217
+ validation_split_percentage: Optional[int] = field(
218
+ default=5,
219
+ metadata={
220
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
221
+ },
222
+ )
223
+ max_seq_length: Optional[int] = field(
224
+ default=None,
225
+ metadata={
226
+ "help": (
227
+ "The maximum total input sequence length after tokenization and masking. Sequences longer than this"
228
+ " will be truncated. Default to the max input length of the model."
229
+ )
230
+ },
231
+ )
232
+ preprocessing_num_workers: Optional[int] = field(
233
+ default=None,
234
+ metadata={"help": "The number of processes to use for the preprocessing."},
235
+ )
236
+ mlm_probability: float = field(
237
+ default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
238
+ )
239
+ mean_noise_span_length: float = field(
240
+ default=3.0,
241
+ metadata={"help": "Mean span length of masked tokens"},
242
+ )
243
+
244
+ def __post_init__(self):
245
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
246
+ raise ValueError("Need either a dataset name or a training/validation file.")
247
+ else:
248
+ if self.train_file is not None:
249
+ extension = self.train_file.split(".")[-1]
250
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
251
+ if self.validation_file is not None:
252
+ extension = self.validation_file.split(".")[-1]
253
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
254
+
255
+
256
+ def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
257
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
258
+
259
+ Training parameters to avoid padding with random_spans_noise_mask.
260
+ When training a model with random_spans_noise_mask, we would like to set the other
261
+ training hyperparmeters in a way that avoids padding.
262
+ This function helps us compute these hyperparameters.
263
+ We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
264
+ and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
265
+ This function tells us the required number of tokens in the raw example (for split_tokens())
266
+ as well as the length of the encoded targets. Note that this function assumes
267
+ the inputs and targets will have EOS appended and includes that in the reported length.
268
+
269
+ Args:
270
+ inputs_length: an integer - desired length of the tokenized inputs sequence
271
+ noise_density: a float
272
+ mean_noise_span_length: a float
273
+ Returns:
274
+ tokens_length: length of original text in tokens
275
+ targets_length: an integer - length in tokens of encoded targets sequence
276
+ """
277
+
278
+ def _tokens_length_to_inputs_length_targets_length(tokens_length):
279
+ num_noise_tokens = int(round(tokens_length * noise_density))
280
+ num_nonnoise_tokens = tokens_length - num_noise_tokens
281
+ num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
282
+ # inputs contain all nonnoise tokens, sentinels for all noise spans
283
+ # and one EOS token.
284
+ _input_length = num_nonnoise_tokens + num_noise_spans + 1
285
+ _output_length = num_noise_tokens + num_noise_spans + 1
286
+ return _input_length, _output_length
287
+
288
+ tokens_length = inputs_length
289
+
290
+ while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
291
+ tokens_length += 1
292
+
293
+ inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
294
+
295
+ # minor hack to get the targets length to be equal to inputs length
296
+ # which is more likely to have been set to a nice round number.
297
+ if noise_density == 0.5 and targets_length > inputs_length:
298
+ tokens_length -= 1
299
+ targets_length -= 1
300
+ return tokens_length, targets_length
301
+
302
+
303
+ @flax.struct.dataclass
304
+ class FlaxDataCollatorForT5MLM:
305
+ """
306
+ Data collator used for T5 span-masked language modeling.
307
+ It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
308
+ For more information on how T5 span-masked language modeling works, one can take a look
309
+ at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
310
+ or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
311
+
312
+ Args:
313
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
314
+ The tokenizer used for encoding the data.
315
+ noise_density (:obj:`float`):
316
+ The probability with which to (randomly) mask tokens in the input.
317
+ mean_noise_span_length (:obj:`float`):
318
+ The average span length of the masked tokens.
319
+ input_length (:obj:`int`):
320
+ The expected input length after masking.
321
+ target_length (:obj:`int`):
322
+ The expected target length after masking.
323
+ pad_token_id: (:obj:`int`):
324
+ The pad token id of the model
325
+ decoder_start_token_id: (:obj:`int):
326
+ The decoder start token id of the model
327
+ """
328
+
329
+ tokenizer: PreTrainedTokenizerBase
330
+ noise_density: float
331
+ mean_noise_span_length: float
332
+ input_length: int
333
+ target_length: int
334
+ pad_token_id: int
335
+ decoder_start_token_id: int
336
+
337
+ def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
338
+ # convert list to dict and tensorize input
339
+ batch = BatchEncoding(
340
+ {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
341
+ )
342
+
343
+ input_ids = batch["input_ids"]
344
+ batch_size, expandend_input_length = input_ids.shape
345
+
346
+ mask_indices = np.asarray([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
347
+ labels_mask = ~mask_indices
348
+
349
+ input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
350
+ labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
351
+
352
+ batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
353
+ batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
354
+
355
+ if batch["input_ids"].shape[-1] != self.input_length:
356
+ raise ValueError(
357
+ f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
358
+ f" should be {self.input_length}."
359
+ )
360
+
361
+ if batch["labels"].shape[-1] != self.target_length:
362
+ raise ValueError(
363
+ f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
364
+ f" {self.target_length}."
365
+ )
366
+
367
+ # to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
368
+ batch["decoder_input_ids"] = shift_tokens_right(
369
+ batch["labels"], self.pad_token_id, self.decoder_start_token_id
370
+ )
371
+
372
+ return batch
373
+
374
+ def create_sentinel_ids(self, mask_indices):
375
+ """
376
+ Sentinel ids creation given the indices that should be masked.
377
+ The start indices of each mask are replaced by the sentinel ids in increasing
378
+ order. Consecutive mask indices to be deleted are replaced with `-1`.
379
+ """
380
+ start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
381
+ start_indices[:, 0] = mask_indices[:, 0]
382
+
383
+ sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
384
+ sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
385
+ sentinel_ids -= mask_indices - start_indices
386
+
387
+ return sentinel_ids
388
+
389
+ def filter_input_ids(self, input_ids, sentinel_ids):
390
+ """
391
+ Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
392
+ This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
393
+ """
394
+ batch_size = input_ids.shape[0]
395
+
396
+ input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
397
+ # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
398
+ # masked tokens coming after sentinel tokens and should be removed
399
+ input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
400
+ input_ids = np.concatenate(
401
+ [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
402
+ )
403
+ return input_ids
404
+
405
+ def random_spans_noise_mask(self, length):
406
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
407
+
408
+ Noise mask consisting of random spans of noise tokens.
409
+ The number of noise tokens and the number of noise spans and non-noise spans
410
+ are determined deterministically as follows:
411
+ num_noise_tokens = round(length * noise_density)
412
+ num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
413
+ Spans alternate between non-noise and noise, beginning with non-noise.
414
+ Subject to the above restrictions, all masks are equally likely.
415
+
416
+ Args:
417
+ length: an int32 scalar (length of the incoming token sequence)
418
+ noise_density: a float - approximate density of output mask
419
+ mean_noise_span_length: a number
420
+
421
+ Returns:
422
+ a boolean tensor with shape [length]
423
+ """
424
+
425
+ orig_length = length
426
+
427
+ num_noise_tokens = int(np.round(length * self.noise_density))
428
+ num_nonnoise_tokens = length - num_noise_tokens
429
+ # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
430
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
431
+ # num_noise_tokens should be less than num_noise_tokens and num_nonnoise_tokens
432
+ num_noise_spans = int(np.round(min(num_noise_tokens, num_nonnoise_tokens) / self.mean_noise_span_length))
433
+
434
+ # avoid degeneracy by ensuring positive number of noise spans
435
+ num_noise_spans = max(num_noise_spans, 1)
436
+
437
+ # pick the lengths of the noise spans and the non-noise spans
438
+ def _random_segmentation(num_items, num_segments):
439
+ """Partition a sequence of items randomly into non-empty segments.
440
+ Args:
441
+ num_items: an integer scalar > 0
442
+ num_segments: an integer scalar in [1, num_items]
443
+ Returns:
444
+ a Tensor with shape [num_segments] containing positive integers that add
445
+ up to num_items
446
+ """
447
+ mask_indices = np.arange(num_items - 1) < (num_segments - 1)
448
+ np.random.shuffle(mask_indices)
449
+ first_in_segment = np.pad(mask_indices, [[1, 0]])
450
+ segment_id = np.cumsum(first_in_segment)
451
+ # count length of sub segments assuming that list is sorted
452
+ _, segment_length = np.unique(segment_id, return_counts=True)
453
+ return segment_length
454
+
455
+ noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
456
+ nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
457
+
458
+ interleaved_span_lengths = np.reshape(
459
+ np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
460
+ )
461
+ span_starts = np.cumsum(interleaved_span_lengths)[:-1]
462
+ span_start_indicator = np.zeros((length,), dtype=np.int8)
463
+ span_start_indicator[span_starts] = True
464
+ span_num = np.cumsum(span_start_indicator)
465
+ is_noise = np.equal(span_num % 2, 1)
466
+
467
+ return is_noise[:orig_length]
468
+
469
+
470
+ def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
471
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
472
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
473
+ num_samples = len(samples_idx)
474
+ if drop_last:
475
+ samples_to_remove = num_samples % batch_size
476
+ if samples_to_remove != 0:
477
+ samples_idx = samples_idx[:-samples_to_remove]
478
+ sections_split = num_samples // batch_size
479
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
480
+ else:
481
+ sections_split = math.ceil(num_samples / batch_size)
482
+ samples_idx = np.array_split(samples_idx, sections_split)
483
+ return samples_idx
484
+
485
+
486
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
487
+ summary_writer.scalar("train_time", train_time, step)
488
+
489
+ train_metrics = get_metrics(train_metrics)
490
+ for key, vals in train_metrics.items():
491
+ tag = f"train_{key}"
492
+ for i, val in enumerate(vals):
493
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
494
+
495
+
496
+ def write_eval_metric(summary_writer, eval_metrics, step):
497
+ for metric_name, value in eval_metrics.items():
498
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
499
+
500
+
501
+ def main():
502
+ # See all possible arguments in src/transformers/training_args.py
503
+ # or by passing the --help flag to this script.
504
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
505
+
506
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
507
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
508
+ # If we pass only one argument to the script and it's the path to a json file,
509
+ # let's parse it to get our arguments.
510
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
511
+ else:
512
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
513
+
514
+ if model_args.use_auth_token is not None:
515
+ warnings.warn(
516
+ "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
517
+ FutureWarning,
518
+ )
519
+ if model_args.token is not None:
520
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
521
+ model_args.token = model_args.use_auth_token
522
+
523
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
524
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
525
+ send_example_telemetry("run_t5_mlm", model_args, data_args, framework="flax")
526
+
527
+ if (
528
+ os.path.exists(training_args.output_dir)
529
+ and os.listdir(training_args.output_dir)
530
+ and training_args.do_train
531
+ and not training_args.overwrite_output_dir
532
+ ):
533
+ raise ValueError(
534
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
535
+ "Use --overwrite_output_dir to overcome."
536
+ )
537
+
538
+ # Setup logging
539
+ logging.basicConfig(
540
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
541
+ level=logging.INFO,
542
+ datefmt="[%X]",
543
+ )
544
+
545
+ # Log on each process the small summary:
546
+ logger = logging.getLogger(__name__)
547
+
548
+ # Set the verbosity to info of the Transformers logger (on main process only):
549
+ logger.info(f"Training/evaluation parameters {training_args}")
550
+
551
+ # Set seed before initializing model.
552
+ set_seed(training_args.seed)
553
+
554
+ # Handle the repository creation
555
+ if training_args.push_to_hub:
556
+ # Retrieve of infer repo_name
557
+ repo_name = training_args.hub_model_id
558
+ if repo_name is None:
559
+ repo_name = Path(training_args.output_dir).absolute().name
560
+ # Create repo and retrieve repo_id
561
+ repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
562
+ # Clone repo locally
563
+ repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
564
+
565
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
566
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
567
+ # (the dataset will be downloaded automatically from the datasets Hub).
568
+ #
569
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
570
+ # 'text' is found. You can easily tweak this behavior (see below).
571
+ if data_args.dataset_name is not None:
572
+ # Downloading and loading a dataset from the hub.
573
+ datasets = load_dataset(
574
+ data_args.dataset_name,
575
+ data_args.dataset_config_name,
576
+ cache_dir=model_args.cache_dir,
577
+ token=model_args.token,
578
+ num_proc=data_args.preprocessing_num_workers,
579
+ )
580
+
581
+ if "validation" not in datasets.keys():
582
+ datasets["validation"] = load_dataset(
583
+ data_args.dataset_name,
584
+ data_args.dataset_config_name,
585
+ split=f"train[:{data_args.validation_split_percentage}%]",
586
+ cache_dir=model_args.cache_dir,
587
+ token=model_args.token,
588
+ num_proc=data_args.preprocessing_num_workers,
589
+ )
590
+ datasets["train"] = load_dataset(
591
+ data_args.dataset_name,
592
+ data_args.dataset_config_name,
593
+ split=f"train[{data_args.validation_split_percentage}%:]",
594
+ cache_dir=model_args.cache_dir,
595
+ token=model_args.token,
596
+ num_proc=data_args.preprocessing_num_workers,
597
+ )
598
+ else:
599
+ data_files = {}
600
+ if data_args.train_file is not None:
601
+ data_files["train"] = data_args.train_file
602
+ extension = data_args.train_file.split(".")[-1]
603
+ if data_args.validation_file is not None:
604
+ data_files["validation"] = data_args.validation_file
605
+ extension = data_args.validation_file.split(".")[-1]
606
+ if extension == "txt":
607
+ extension = "text"
608
+ datasets = load_dataset(
609
+ extension,
610
+ data_files=data_files,
611
+ cache_dir=model_args.cache_dir,
612
+ token=model_args.token,
613
+ num_proc=data_args.preprocessing_num_workers,
614
+ )
615
+
616
+ if "validation" not in datasets.keys():
617
+ datasets["validation"] = load_dataset(
618
+ extension,
619
+ data_files=data_files,
620
+ split=f"train[:{data_args.validation_split_percentage}%]",
621
+ cache_dir=model_args.cache_dir,
622
+ token=model_args.token,
623
+ num_proc=data_args.preprocessing_num_workers,
624
+ )
625
+ datasets["train"] = load_dataset(
626
+ extension,
627
+ data_files=data_files,
628
+ split=f"train[{data_args.validation_split_percentage}%:]",
629
+ cache_dir=model_args.cache_dir,
630
+ token=model_args.token,
631
+ num_proc=data_args.preprocessing_num_workers,
632
+ )
633
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
634
+ # https://huggingface.co/docs/datasets/loading_datasets.
635
+
636
+ # Load pretrained model and tokenizer
637
+
638
+ if model_args.tokenizer_name:
639
+ tokenizer = AutoTokenizer.from_pretrained(
640
+ model_args.tokenizer_name,
641
+ cache_dir=model_args.cache_dir,
642
+ use_fast=model_args.use_fast_tokenizer,
643
+ token=model_args.token,
644
+ )
645
+ elif model_args.model_name_or_path:
646
+ tokenizer = AutoTokenizer.from_pretrained(
647
+ model_args.model_name_or_path,
648
+ cache_dir=model_args.cache_dir,
649
+ use_fast=model_args.use_fast_tokenizer,
650
+ token=model_args.token,
651
+ )
652
+ else:
653
+ raise ValueError(
654
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
655
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
656
+ )
657
+
658
+ if model_args.config_name:
659
+ config = T5Config.from_pretrained(
660
+ model_args.config_name,
661
+ cache_dir=model_args.cache_dir,
662
+ vocab_size=len(tokenizer),
663
+ token=model_args.token,
664
+ )
665
+ elif model_args.model_name_or_path:
666
+ config = T5Config.from_pretrained(
667
+ model_args.model_name_or_path,
668
+ cache_dir=model_args.cache_dir,
669
+ token=model_args.token,
670
+ )
671
+ else:
672
+ config = CONFIG_MAPPING[model_args.model_type]()
673
+ logger.warning("You are instantiating a new config instance from scratch.")
674
+
675
+ # Preprocessing the datasets.
676
+ # First we tokenize all the texts.
677
+ if training_args.do_train:
678
+ column_names = datasets["train"].column_names
679
+ else:
680
+ column_names = datasets["validation"].column_names
681
+ text_column_name = "text" if "text" in column_names else column_names[0]
682
+
683
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
684
+
685
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
686
+ # Since we make sure that all sequences are of the same length, no attention_mask is needed.
687
+ def tokenize_function(examples):
688
+ return tokenizer(examples[text_column_name], return_attention_mask=False)
689
+
690
+ tokenized_datasets = datasets.map(
691
+ tokenize_function,
692
+ batched=True,
693
+ num_proc=data_args.preprocessing_num_workers,
694
+ remove_columns=column_names,
695
+ load_from_cache_file=not data_args.overwrite_cache,
696
+ )
697
+
698
+ # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
699
+ # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
700
+ # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
701
+ expanded_inputs_length, targets_length = compute_input_and_target_lengths(
702
+ inputs_length=max_seq_length,
703
+ noise_density=data_args.mlm_probability,
704
+ mean_noise_span_length=data_args.mean_noise_span_length,
705
+ )
706
+
707
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
708
+ def group_texts(examples):
709
+ # Concatenate all texts.
710
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
711
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
712
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
713
+ # customize this part to your needs.
714
+ if total_length >= expanded_inputs_length:
715
+ total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
716
+ # Split by chunks of max_len.
717
+ result = {
718
+ k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
719
+ for k, t in concatenated_examples.items()
720
+ }
721
+ return result
722
+
723
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
724
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
725
+ # might be slower to preprocess.
726
+ #
727
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
728
+ # https://huggingface.co/docs/datasets/process#map
729
+ tokenized_datasets = tokenized_datasets.map(
730
+ group_texts,
731
+ batched=True,
732
+ num_proc=data_args.preprocessing_num_workers,
733
+ load_from_cache_file=not data_args.overwrite_cache,
734
+ )
735
+
736
+ # Enable tensorboard only on the master node
737
+ has_tensorboard = is_tensorboard_available()
738
+ if has_tensorboard and jax.process_index() == 0:
739
+ try:
740
+ from flax.metrics.tensorboard import SummaryWriter
741
+
742
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
743
+ except ImportError as ie:
744
+ has_tensorboard = False
745
+ logger.warning(
746
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
747
+ )
748
+ else:
749
+ logger.warning(
750
+ "Unable to display metrics through TensorBoard because the package is not installed: "
751
+ "Please run pip install tensorboard to enable."
752
+ )
753
+
754
+ # Initialize our training
755
+ rng = jax.random.PRNGKey(training_args.seed)
756
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
757
+
758
+ if model_args.model_name_or_path:
759
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
760
+ model_args.model_name_or_path,
761
+ config=config,
762
+ seed=training_args.seed,
763
+ dtype=getattr(jnp, model_args.dtype),
764
+ token=model_args.token,
765
+ )
766
+ else:
767
+ config.vocab_size = len(tokenizer)
768
+ model = FlaxT5ForConditionalGeneration(
769
+ config,
770
+ seed=training_args.seed,
771
+ dtype=getattr(jnp, model_args.dtype),
772
+ )
773
+
774
+ # Data collator
775
+ # This one will take care of randomly masking the tokens.
776
+ data_collator = FlaxDataCollatorForT5MLM(
777
+ tokenizer=tokenizer,
778
+ noise_density=data_args.mlm_probability,
779
+ mean_noise_span_length=data_args.mean_noise_span_length,
780
+ input_length=max_seq_length,
781
+ target_length=targets_length,
782
+ pad_token_id=model.config.pad_token_id,
783
+ decoder_start_token_id=model.config.decoder_start_token_id,
784
+ )
785
+
786
+ # Store some constant
787
+ num_epochs = int(training_args.num_train_epochs)
788
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
789
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
790
+ eval_batch_size = per_device_eval_batch_size * jax.device_count()
791
+
792
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
793
+
794
+ num_of_hosts = jax.process_count()
795
+ current_host_idx = jax.process_index()
796
+
797
+ # Create learning rate schedule
798
+ warmup_fn = optax.linear_schedule(
799
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
800
+ )
801
+ decay_fn = optax.linear_schedule(
802
+ init_value=training_args.learning_rate,
803
+ end_value=0,
804
+ transition_steps=num_train_steps - training_args.warmup_steps,
805
+ )
806
+ linear_decay_lr_schedule_fn = optax.join_schedules(
807
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
808
+ )
809
+
810
+ # We use Optax's "masking" functionality to not apply weight decay
811
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
812
+ # mask boolean with the same structure as the parameters.
813
+ # The mask is True for parameters that should be decayed.
814
+ def decay_mask_fn(params):
815
+ flat_params = traverse_util.flatten_dict(params)
816
+ # find out all LayerNorm parameters
817
+ layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
818
+ layer_norm_named_params = {
819
+ layer[-2:]
820
+ for layer_norm_name in layer_norm_candidates
821
+ for layer in flat_params.keys()
822
+ if layer_norm_name in "".join(layer).lower()
823
+ }
824
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
825
+ return traverse_util.unflatten_dict(flat_mask)
826
+
827
+ # create adam optimizer
828
+ if training_args.adafactor:
829
+ # We use the default parameters here to initialize adafactor,
830
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
831
+ optimizer = optax.adafactor(
832
+ learning_rate=linear_decay_lr_schedule_fn,
833
+ )
834
+ else:
835
+ optimizer = optax.adamw(
836
+ learning_rate=linear_decay_lr_schedule_fn,
837
+ b1=training_args.adam_beta1,
838
+ b2=training_args.adam_beta2,
839
+ weight_decay=training_args.weight_decay,
840
+ mask=decay_mask_fn,
841
+ )
842
+
843
+ # Setup train state
844
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
845
+
846
+ # Define gradient update step fn
847
+ def train_step(state, batch, dropout_rng):
848
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
849
+
850
+ def loss_fn(params):
851
+ labels = batch.pop("labels")
852
+
853
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
854
+
855
+ # compute loss
856
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
857
+
858
+ return loss
859
+
860
+ grad_fn = jax.value_and_grad(loss_fn)
861
+ loss, grad = grad_fn(state.params)
862
+ grad = jax.lax.pmean(grad, "batch")
863
+ new_state = state.apply_gradients(grads=grad)
864
+
865
+ metrics = jax.lax.pmean(
866
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
867
+ )
868
+
869
+ return new_state, metrics, new_dropout_rng
870
+
871
+ # Create parallel version of the train step
872
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
873
+
874
+ # Define eval fn
875
+ def eval_step(params, batch):
876
+ labels = batch.pop("labels")
877
+
878
+ logits = model(**batch, params=params, train=False)[0]
879
+
880
+ # compute loss
881
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
882
+
883
+ # compute accuracy
884
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)
885
+
886
+ # summarize metrics
887
+ metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
888
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
889
+
890
+ return metrics
891
+
892
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
893
+
894
+ # Replicate the train state on each device
895
+ state = jax_utils.replicate(state)
896
+
897
+ train_time = 0
898
+ epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
899
+ for epoch in epochs:
900
+ # ======================== Training ================================
901
+ train_start = time.time()
902
+ train_metrics = []
903
+
904
+ # Create sampling rng
905
+ rng, input_rng = jax.random.split(rng)
906
+
907
+ # Generate an epoch by shuffling sampling indices from the train dataset
908
+ num_train_samples = len(tokenized_datasets["train"])
909
+ # Avoid using jax.numpy here in case of TPU training
910
+ train_samples_idx = np.random.permutation(np.arange(num_train_samples))
911
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
912
+
913
+ # Gather the indexes for creating the batch and do a training step
914
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
915
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
916
+ model_inputs = data_collator(samples)
917
+
918
+ local_host_model_inputs = {
919
+ key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx]
920
+ for key, value in model_inputs.data.items()
921
+ }
922
+
923
+ # Model forward
924
+ model_inputs = shard(local_host_model_inputs)
925
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
926
+ train_metrics.append(train_metric)
927
+
928
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
929
+
930
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
931
+ # Save metrics
932
+ train_metric = jax_utils.unreplicate(train_metric)
933
+ train_time += time.time() - train_start
934
+ if has_tensorboard and jax.process_index() == 0:
935
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
936
+
937
+ epochs.write(
938
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
939
+ f" {train_metric['learning_rate'].mean()})"
940
+ )
941
+
942
+ train_metrics = []
943
+
944
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
945
+ # ======================== Evaluating ==============================
946
+ num_eval_samples = len(tokenized_datasets["validation"])
947
+ # Avoid using jax.numpy here in case of TPU training
948
+ eval_samples_idx = np.arange(num_eval_samples)
949
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
950
+
951
+ eval_metrics = []
952
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
953
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
954
+ model_inputs = data_collator(samples)
955
+
956
+ # Model forward
957
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
958
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
959
+ )
960
+ eval_metrics.append(metrics)
961
+
962
+ # get eval metrics
963
+ eval_metrics = get_metrics(eval_metrics)
964
+ eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
965
+
966
+ # Update progress bar
967
+ epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
968
+
969
+ # Save metrics
970
+ if has_tensorboard and jax.process_index() == 0:
971
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
972
+
973
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
974
+ # save checkpoint after each epoch and push checkpoint to the hub
975
+ if jax.process_index() == 0:
976
+ params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
977
+ model.save_pretrained(training_args.output_dir, params=params)
978
+ tokenizer.save_pretrained(training_args.output_dir)
979
+ if training_args.push_to_hub:
980
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
981
+
982
+ # Eval after training
983
+ if training_args.do_eval:
984
+ num_eval_samples = len(tokenized_datasets["validation"])
985
+ # Avoid using jax.numpy here in case of TPU training
986
+ eval_samples_idx = np.arange(num_eval_samples)
987
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
988
+
989
+ eval_metrics = []
990
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
991
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
992
+ model_inputs = data_collator(samples)
993
+
994
+ # Model forward
995
+ metrics = pad_shard_unpad(p_eval_step, static_return=True)(
996
+ state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
997
+ )
998
+ eval_metrics.append(metrics)
999
+
1000
+ # get eval metrics
1001
+ eval_metrics = get_metrics(eval_metrics)
1002
+ eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
1003
+
1004
+ if jax.process_index() == 0:
1005
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
1006
+ path = os.path.join(training_args.output_dir, "eval_results.json")
1007
+ with open(path, "w") as f:
1008
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
1009
+
1010
+
1011
+ if __name__ == "__main__":
1012
+ main()