Hooman Sedghamiz commited on
Commit
01ae861
1 Parent(s): 3a36a39

pushing a template clm training script for gpt2

Browse files
Files changed (1) hide show
  1. src/run_clm_flax.py +625 -0
src/run_clm_flax.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=causal-lm
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Callable, Optional
32
+
33
+ import datasets
34
+ from datasets import Dataset, load_dataset
35
+ from tqdm import tqdm
36
+
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ from flax import jax_utils, traverse_util
42
+ from flax.jax_utils import unreplicate
43
+ from flax.training import train_state
44
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
45
+ from transformers import (
46
+ CONFIG_MAPPING,
47
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
48
+ AutoConfig,
49
+ AutoTokenizer,
50
+ FlaxAutoModelForCausalLM,
51
+ HfArgumentParser,
52
+ TrainingArguments,
53
+ is_tensorboard_available,
54
+ )
55
+ from transformers.testing_utils import CaptureLogger
56
+
57
+
58
+ logger = logging.getLogger(__name__)
59
+
60
+ # Cache the result
61
+ has_tensorboard = is_tensorboard_available()
62
+ if has_tensorboard:
63
+ try:
64
+ from flax.metrics.tensorboard import SummaryWriter
65
+ except ImportError as ie:
66
+ has_tensorboard = False
67
+ print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
68
+
69
+ else:
70
+ print(
71
+ "Unable to display metrics through TensorBoard because the package is not installed: "
72
+ "Please run pip install tensorboard to enable."
73
+ )
74
+
75
+
76
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
77
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
78
+
79
+
80
+ @dataclass
81
+ class ModelArguments:
82
+ """
83
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
84
+ """
85
+
86
+ model_name_or_path: Optional[str] = field(
87
+ default=None,
88
+ metadata={
89
+ "help": "The model checkpoint for weights initialization."
90
+ "Don't set if you want to train a model from scratch."
91
+ },
92
+ )
93
+ model_type: Optional[str] = field(
94
+ default=None,
95
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
96
+ )
97
+ config_name: Optional[str] = field(
98
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
99
+ )
100
+ tokenizer_name: Optional[str] = field(
101
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
102
+ )
103
+ cache_dir: Optional[str] = field(
104
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
105
+ )
106
+ use_fast_tokenizer: bool = field(
107
+ default=True,
108
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
109
+ )
110
+ dtype: Optional[str] = field(
111
+ default="float32",
112
+ metadata={
113
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
114
+ },
115
+ )
116
+
117
+
118
+ @dataclass
119
+ class DataTrainingArguments:
120
+ """
121
+ Arguments pertaining to what data we are going to input our model for training and eval.
122
+ """
123
+
124
+ dataset_name: Optional[str] = field(
125
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
126
+ )
127
+ dataset_config_name: Optional[str] = field(
128
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
129
+ )
130
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
131
+ validation_file: Optional[str] = field(
132
+ default=None,
133
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
134
+ )
135
+ max_train_samples: Optional[int] = field(
136
+ default=None,
137
+ metadata={
138
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
139
+ "value if set."
140
+ },
141
+ )
142
+ max_eval_samples: Optional[int] = field(
143
+ default=None,
144
+ metadata={
145
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
146
+ "value if set."
147
+ },
148
+ )
149
+ overwrite_cache: bool = field(
150
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
151
+ )
152
+ validation_split_percentage: Optional[int] = field(
153
+ default=5,
154
+ metadata={
155
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
156
+ },
157
+ )
158
+ block_size: Optional[int] = field(
159
+ default=None,
160
+ metadata={
161
+ "help": "Optional input sequence length after tokenization. "
162
+ "The training dataset will be truncated in block of this size for training. "
163
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
164
+ },
165
+ )
166
+ overwrite_cache: bool = field(
167
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
168
+ )
169
+ preprocessing_num_workers: Optional[int] = field(
170
+ default=None,
171
+ metadata={"help": "The number of processes to use for the preprocessing."},
172
+ )
173
+
174
+ def __post_init__(self):
175
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
176
+ raise ValueError("Need either a dataset name or a training/validation file.")
177
+ else:
178
+ if self.train_file is not None:
179
+ extension = self.train_file.split(".")[-1]
180
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
181
+ if self.validation_file is not None:
182
+ extension = self.validation_file.split(".")[-1]
183
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
184
+
185
+
186
+ class TrainState(train_state.TrainState):
187
+ dropout_rng: jnp.ndarray
188
+
189
+ def replicate(self):
190
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
191
+
192
+
193
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
194
+ """
195
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
196
+ Shuffle batches if `shuffle` is `True`.
197
+ """
198
+ steps_per_epoch = len(dataset) // batch_size
199
+
200
+ if shuffle:
201
+ batch_idx = jax.random.permutation(rng, len(dataset))
202
+ else:
203
+ batch_idx = jnp.arange(len(dataset))
204
+
205
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
206
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
207
+
208
+ for idx in batch_idx:
209
+ batch = dataset[idx]
210
+ batch = {k: jnp.array(v) for k, v in batch.items()}
211
+
212
+ batch = shard(batch)
213
+
214
+ yield batch
215
+
216
+
217
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
218
+ summary_writer.scalar("train_time", train_time, step)
219
+
220
+ train_metrics = get_metrics(train_metrics)
221
+ for key, vals in train_metrics.items():
222
+ tag = f"train_{key}"
223
+ for i, val in enumerate(vals):
224
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
225
+
226
+ for metric_name, value in eval_metrics.items():
227
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
228
+
229
+
230
+ def create_learning_rate_fn(
231
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
232
+ ) -> Callable[[int], jnp.array]:
233
+ """Returns a linear warmup, linear_decay learning rate function."""
234
+ steps_per_epoch = train_ds_size // train_batch_size
235
+ num_train_steps = steps_per_epoch * num_train_epochs
236
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
237
+ decay_fn = optax.linear_schedule(
238
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
239
+ )
240
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
241
+ return schedule_fn
242
+
243
+
244
+ def main():
245
+ # See all possible arguments in src/transformers/training_args.py
246
+ # or by passing the --help flag to this script.
247
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
248
+
249
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
250
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
251
+ # If we pass only one argument to the script and it's the path to a json file,
252
+ # let's parse it to get our arguments.
253
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
254
+ else:
255
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
256
+
257
+ if (
258
+ os.path.exists(training_args.output_dir)
259
+ and os.listdir(training_args.output_dir)
260
+ and training_args.do_train
261
+ and not training_args.overwrite_output_dir
262
+ ):
263
+ raise ValueError(
264
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
265
+ "Use --overwrite_output_dir to overcome."
266
+ )
267
+
268
+ # Make one log on every process with the configuration for debugging.
269
+ logging.basicConfig(
270
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
271
+ datefmt="%m/%d/%Y %H:%M:%S",
272
+ level=logging.INFO,
273
+ )
274
+ # Setup logging, we only want one process per machine to log things on the screen.
275
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
276
+ if jax.process_index() == 0:
277
+ datasets.utils.logging.set_verbosity_warning()
278
+ transformers.utils.logging.set_verbosity_info()
279
+ else:
280
+ datasets.utils.logging.set_verbosity_error()
281
+ transformers.utils.logging.set_verbosity_error()
282
+
283
+ # Set the verbosity to info of the Transformers logger (on main process only):
284
+ logger.info(f"Training/evaluation parameters {training_args}")
285
+
286
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
287
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
288
+ # (the dataset will be downloaded automatically from the datasets Hub).
289
+ #
290
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
291
+ # 'text' is found. You can easily tweak this behavior (see below).
292
+ #
293
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
294
+ # download the dataset.
295
+ if data_args.dataset_name is not None:
296
+ # Downloading and loading a dataset from the hub.
297
+ dataset = load_dataset(
298
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
299
+ )
300
+
301
+ if "validation" not in dataset.keys():
302
+ dataset["validation"] = load_dataset(
303
+ data_args.dataset_name,
304
+ data_args.dataset_config_name,
305
+ split=f"train[:{data_args.validation_split_percentage}%]",
306
+ cache_dir=model_args.cache_dir,
307
+ )
308
+ dataset["train"] = load_dataset(
309
+ data_args.dataset_name,
310
+ data_args.dataset_config_name,
311
+ split=f"train[{data_args.validation_split_percentage}%:]",
312
+ cache_dir=model_args.cache_dir,
313
+ )
314
+ else:
315
+ data_files = {}
316
+ if data_args.train_file is not None:
317
+ data_files["train"] = data_args.train_file
318
+ if data_args.validation_file is not None:
319
+ data_files["validation"] = data_args.validation_file
320
+ extension = data_args.train_file.split(".")[-1]
321
+ if extension == "txt":
322
+ extension = "text"
323
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
324
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
325
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
326
+
327
+ # Load pretrained model and tokenizer
328
+
329
+ # Distributed training:
330
+ # The .from_pretrained methods guarantee that only one local process can concurrently
331
+ # download model & vocab.
332
+ if model_args.config_name:
333
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
334
+ elif model_args.model_name_or_path:
335
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
336
+ else:
337
+ config = CONFIG_MAPPING[model_args.model_type]()
338
+ logger.warning("You are instantiating a new config instance from scratch.")
339
+
340
+ if model_args.tokenizer_name:
341
+ tokenizer = AutoTokenizer.from_pretrained(
342
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
343
+ )
344
+ elif model_args.model_name_or_path:
345
+ tokenizer = AutoTokenizer.from_pretrained(
346
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
347
+ )
348
+ else:
349
+ raise ValueError(
350
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
351
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
352
+ )
353
+
354
+ if model_args.model_name_or_path:
355
+ model = FlaxAutoModelForCausalLM.from_pretrained(
356
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
357
+ )
358
+ else:
359
+ model = FlaxAutoModelForCausalLM.from_config(
360
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
361
+ )
362
+
363
+ # Preprocessing the datasets.
364
+ # First we tokenize all the texts.
365
+ if training_args.do_train:
366
+ column_names = dataset["train"].column_names
367
+ else:
368
+ column_names = dataset["validation"].column_names
369
+ text_column_name = "text" if "text" in column_names else column_names[0]
370
+
371
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
372
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
373
+
374
+ def tokenize_function(examples):
375
+ with CaptureLogger(tok_logger) as cl:
376
+ output = tokenizer(examples[text_column_name])
377
+ # clm input could be much much longer than block_size
378
+ if "Token indices sequence length is longer than the" in cl.out:
379
+ tok_logger.warning(
380
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
381
+ )
382
+ return output
383
+
384
+ tokenized_datasets = dataset.map(
385
+ tokenize_function,
386
+ batched=True,
387
+ num_proc=data_args.preprocessing_num_workers,
388
+ remove_columns=column_names,
389
+ load_from_cache_file=not data_args.overwrite_cache,
390
+ )
391
+
392
+ if data_args.block_size is None:
393
+ block_size = tokenizer.model_max_length
394
+ if block_size > config.max_position_embeddings:
395
+ logger.warning(
396
+ f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
397
+ "Picking 1024 instead. You can change that default value by passing --block_size xxx."
398
+ )
399
+ block_size = 1024
400
+ else:
401
+ if data_args.block_size > tokenizer.model_max_length:
402
+ logger.warning(
403
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
404
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
405
+ )
406
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
407
+
408
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
409
+ def group_texts(examples):
410
+ # Concatenate all texts.
411
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
412
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
413
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
414
+ # customize this part to your needs.
415
+ total_length = (total_length // block_size) * block_size
416
+ # Split by chunks of max_len.
417
+ result = {
418
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
419
+ for k, t in concatenated_examples.items()
420
+ }
421
+ result["labels"] = result["input_ids"].copy()
422
+ return result
423
+
424
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
425
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
426
+ # to preprocess.
427
+ #
428
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
429
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
430
+
431
+ lm_datasets = tokenized_datasets.map(
432
+ group_texts,
433
+ batched=True,
434
+ num_proc=data_args.preprocessing_num_workers,
435
+ load_from_cache_file=not data_args.overwrite_cache,
436
+ )
437
+
438
+ if training_args.do_train:
439
+ if "train" not in tokenized_datasets:
440
+ raise ValueError("--do_train requires a train dataset")
441
+ train_dataset = lm_datasets["train"]
442
+ if data_args.max_train_samples is not None:
443
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
444
+
445
+ if training_args.do_eval:
446
+ if "validation" not in tokenized_datasets:
447
+ raise ValueError("--do_eval requires a validation dataset")
448
+ eval_dataset = lm_datasets["validation"]
449
+ if data_args.max_eval_samples is not None:
450
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
451
+
452
+ # Enable tensorboard only on the master node
453
+ if has_tensorboard and jax.process_index() == 0:
454
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
455
+
456
+ # Initialize our training
457
+ rng = jax.random.PRNGKey(training_args.seed)
458
+ rng, dropout_rng = jax.random.split(rng)
459
+
460
+ # Store some constant
461
+ num_epochs = int(training_args.num_train_epochs)
462
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
463
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
464
+ steps_per_epoch = len(train_dataset) // train_batch_size
465
+ total_train_steps = steps_per_epoch * num_epochs
466
+
467
+ # Create learning rate schedule
468
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
469
+ len(train_dataset),
470
+ train_batch_size,
471
+ training_args.num_train_epochs,
472
+ training_args.warmup_steps,
473
+ training_args.learning_rate,
474
+ )
475
+
476
+ # We use Optax's "masking" functionality to not apply weight decay
477
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
478
+ # mask boolean with the same structure as the parameters.
479
+ # The mask is True for parameters that should be decayed.
480
+ # Note that this mask is specifically adapted for FlaxGPT2.
481
+ # For other models, one should correct the layer norm parameter naming
482
+ # accordingly.
483
+ def decay_mask_fn(params):
484
+ flat_params = traverse_util.flatten_dict(params)
485
+ flat_mask = {
486
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
487
+ for path in flat_params
488
+ }
489
+ return traverse_util.unflatten_dict(flat_mask)
490
+
491
+ # create adam optimizer
492
+ adamw = optax.adamw(
493
+ learning_rate=linear_decay_lr_schedule_fn,
494
+ b1=training_args.adam_beta1,
495
+ b2=training_args.adam_beta2,
496
+ eps=training_args.adam_epsilon,
497
+ weight_decay=training_args.weight_decay,
498
+ mask=decay_mask_fn,
499
+ )
500
+
501
+ # Setup train state
502
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
503
+
504
+ def loss_fn(logits, labels):
505
+ shift_logits = logits[..., :-1, :]
506
+ shift_labels = labels[..., 1:]
507
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
508
+ return loss.mean()
509
+
510
+ # Define gradient update step fn
511
+ def train_step(state, batch):
512
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
513
+
514
+ def compute_loss(params):
515
+ labels = batch.pop("labels")
516
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
517
+ loss = loss_fn(logits, labels)
518
+ return loss
519
+
520
+ grad_fn = jax.value_and_grad(compute_loss)
521
+ loss, grad = grad_fn(state.params)
522
+ grad = jax.lax.pmean(grad, "batch")
523
+
524
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
525
+
526
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
527
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
528
+
529
+ return new_state, metrics
530
+
531
+ # Define eval fn
532
+ def eval_step(params, batch):
533
+ labels = batch.pop("labels")
534
+ logits = model(**batch, params=params, train=False)[0]
535
+ loss = loss_fn(logits, labels)
536
+
537
+ # summarize metrics
538
+ metrics = {"loss": loss}
539
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
540
+ return metrics
541
+
542
+ # Create parallel version of the train and eval step
543
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
544
+ p_eval_step = jax.pmap(eval_step, "batch")
545
+
546
+ # Replicate the train state on each device
547
+ state = state.replicate()
548
+
549
+ logger.info("***** Running training *****")
550
+ logger.info(f" Num examples = {len(train_dataset)}")
551
+ logger.info(f" Num Epochs = {num_epochs}")
552
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
553
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
554
+ logger.info(f" Total optimization steps = {total_train_steps}")
555
+
556
+ train_time = 0
557
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
558
+ for epoch in epochs:
559
+ # ======================== Training ================================
560
+ train_start = time.time()
561
+
562
+ # Create sampling rng
563
+ rng, input_rng = jax.random.split(rng)
564
+ train_metrics = []
565
+
566
+ # Generate an epoch by shuffling sampling indices from the train dataset
567
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
568
+ steps_per_epoch = len(train_dataset) // train_batch_size
569
+ # train
570
+ for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
571
+ batch = next(train_loader)
572
+ state, train_metric = p_train_step(state, batch)
573
+ train_metrics.append(train_metric)
574
+
575
+ train_time += time.time() - train_start
576
+
577
+ train_metric = unreplicate(train_metric)
578
+
579
+ epochs.write(
580
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
581
+ )
582
+
583
+ # ======================== Evaluating ==============================
584
+ eval_metrics = []
585
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
586
+ eval_steps = len(eval_dataset) // eval_batch_size
587
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
588
+ # Model forward
589
+ batch = next(eval_loader)
590
+ metrics = p_eval_step(state.params, batch)
591
+ eval_metrics.append(metrics)
592
+
593
+ # normalize eval metrics
594
+ eval_metrics = get_metrics(eval_metrics)
595
+
596
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
597
+
598
+ try:
599
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
600
+ except OverflowError:
601
+ eval_metrics["perplexity"] = float("inf")
602
+
603
+ # Print metrics and update progress bar
604
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
605
+ epochs.write(desc)
606
+ epochs.desc = desc
607
+
608
+ # Save metrics
609
+ if has_tensorboard and jax.process_index() == 0:
610
+ cur_step = epoch * (len(train_dataset) // train_batch_size)
611
+ write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
612
+
613
+ # save checkpoint after each epoch and push checkpoint to the hub
614
+ if jax.process_index() == 0:
615
+ params = jax.device_get(unreplicate(state.params))
616
+ model.save_pretrained(
617
+ training_args.output_dir,
618
+ params=params,
619
+ push_to_hub=training_args.push_to_hub,
620
+ commit_message=f"Saving weights and logs of epoch {epoch+1}",
621
+ )
622
+
623
+
624
+ if __name__ == "__main__":
625
+ main()