nbroad HF staff commited on
Commit
e3b3f1e
1 Parent(s): caafc9c

Upload run_summarization_flax.py

Browse files
Files changed (1) hide show
  1. run_summarization_flax.py +834 -0
run_summarization_flax.py ADDED
@@ -0,0 +1,834 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for summarization.
18
+ """
19
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20
+
21
+ from importlib.metadata import metadata
22
+ import logging
23
+ import os
24
+ import sys
25
+ import time
26
+ from dataclasses import dataclass, field
27
+ from functools import partial
28
+ from pathlib import Path
29
+ from typing import Callable, Optional
30
+
31
+ import datasets
32
+ import nltk # Here to have a nice missing dependency error message early on
33
+ import numpy as np
34
+ from datasets import Dataset, load_dataset, load_metric
35
+ from tqdm import tqdm
36
+
37
+ import jax
38
+ import jax.numpy as jnp
39
+ import optax
40
+ import transformers
41
+ from filelock import FileLock
42
+ from flax import jax_utils, traverse_util
43
+ from flax.jax_utils import unreplicate
44
+ from flax.training import train_state
45
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
46
+ from huggingface_hub import Repository
47
+ from transformers import (
48
+ CONFIG_MAPPING,
49
+ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
50
+ AutoConfig,
51
+ AutoTokenizer,
52
+ FlaxAutoModelForSeq2SeqLM,
53
+ HfArgumentParser,
54
+ TrainingArguments,
55
+ is_tensorboard_available,
56
+ )
57
+ from transformers.file_utils import get_full_repo_name, is_offline_mode
58
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right as shift_tokens_right_fn
59
+
60
+ logger = logging.getLogger(__name__)
61
+
62
+ try:
63
+ nltk.data.find("tokenizers/punkt")
64
+ except (LookupError, OSError):
65
+ if is_offline_mode():
66
+ raise LookupError(
67
+ "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
68
+ )
69
+ with FileLock(".lock") as lock:
70
+ nltk.download("punkt", quiet=True)
71
+
72
+
73
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
74
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ """
80
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
81
+ """
82
+
83
+ model_name_or_path: Optional[str] = field(
84
+ default=None,
85
+ metadata={
86
+ "help": "The model checkpoint for weights initialization."
87
+ "Don't set if you want to train a model from scratch."
88
+ },
89
+ )
90
+ model_type: Optional[str] = field(
91
+ default=None,
92
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
93
+ )
94
+ config_name: Optional[str] = field(
95
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
96
+ )
97
+ tokenizer_name: Optional[str] = field(
98
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
99
+ )
100
+ cache_dir: Optional[str] = field(
101
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
102
+ )
103
+ use_fast_tokenizer: bool = field(
104
+ default=True,
105
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
106
+ )
107
+ dtype: Optional[str] = field(
108
+ default="float32",
109
+ metadata={
110
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
111
+ },
112
+ )
113
+
114
+
115
+ @dataclass
116
+ class DataTrainingArguments:
117
+ """
118
+ Arguments pertaining to what data we are going to input our model for training and eval.
119
+ """
120
+
121
+ pretokenized: bool = field(
122
+ default=False, metadata={"help": "Set if the dataset is already tokenized."}
123
+ )
124
+ dataset_name: Optional[str] = field(
125
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
126
+ )
127
+ dataset_config_name: Optional[str] = field(
128
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
129
+ )
130
+ text_column: Optional[str] = field(
131
+ default=None,
132
+ metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
133
+ )
134
+ summary_column: Optional[str] = field(
135
+ default=None,
136
+ metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
137
+ )
138
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
139
+ validation_file: Optional[str] = field(
140
+ default=None,
141
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
142
+ )
143
+ test_file: Optional[str] = field(
144
+ default=None,
145
+ metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
146
+ )
147
+ max_source_length: Optional[int] = field(
148
+ default=1024,
149
+ metadata={
150
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
151
+ "than this will be truncated, sequences shorter will be padded."
152
+ },
153
+ )
154
+ max_target_length: Optional[int] = field(
155
+ default=128,
156
+ metadata={
157
+ "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
158
+ "than this will be truncated, sequences shorter will be padded."
159
+ },
160
+ )
161
+ val_max_target_length: Optional[int] = field(
162
+ default=None,
163
+ metadata={
164
+ "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
165
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
166
+ "This argument is also used to override the `max_length` param of `model.generate`, which is used "
167
+ "during evaluation."
168
+ },
169
+ )
170
+ max_train_samples: Optional[int] = field(
171
+ default=None,
172
+ metadata={
173
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
174
+ "value if set."
175
+ },
176
+ )
177
+ max_eval_samples: Optional[int] = field(
178
+ default=None,
179
+ metadata={
180
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
181
+ "value if set."
182
+ },
183
+ )
184
+ max_predict_samples: Optional[int] = field(
185
+ default=None,
186
+ metadata={
187
+ "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
188
+ "value if set."
189
+ },
190
+ )
191
+ preprocessing_num_workers: Optional[int] = field(
192
+ default=None,
193
+ metadata={"help": "The number of processes to use for the preprocessing."},
194
+ )
195
+ source_prefix: Optional[str] = field(
196
+ default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
197
+ )
198
+ predict_with_generate: bool = field(
199
+ default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
200
+ )
201
+ num_beams: Optional[int] = field(
202
+ default=None,
203
+ metadata={
204
+ "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
205
+ "which is used during evaluation."
206
+ },
207
+ )
208
+ overwrite_cache: bool = field(
209
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
210
+ )
211
+
212
+ def __post_init__(self):
213
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
214
+ raise ValueError("Need either a dataset name or a training/validation file.")
215
+ else:
216
+ if self.train_file is not None:
217
+ extension = self.train_file.split(".")[-1]
218
+ assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
219
+ if self.validation_file is not None:
220
+ extension = self.validation_file.split(".")[-1]
221
+ assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
222
+ if self.val_max_target_length is None:
223
+ self.val_max_target_length = self.max_target_length
224
+
225
+
226
+ summarization_name_mapping = {
227
+ "amazon_reviews_multi": ("review_body", "review_title"),
228
+ "big_patent": ("description", "abstract"),
229
+ "cnn_dailymail": ("article", "highlights"),
230
+ "orange_sum": ("text", "summary"),
231
+ "pn_summary": ("article", "summary"),
232
+ "psc": ("extract_text", "summary_text"),
233
+ "samsum": ("dialogue", "summary"),
234
+ "thaisum": ("body", "summary"),
235
+ "xglue": ("news_body", "news_title"),
236
+ "xsum": ("document", "summary"),
237
+ "wiki_summary": ("article", "highlights"),
238
+ }
239
+
240
+
241
+ class TrainState(train_state.TrainState):
242
+ dropout_rng: jnp.ndarray
243
+
244
+ def replicate(self):
245
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
246
+
247
+
248
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
249
+ """
250
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
251
+ Shuffle batches if `shuffle` is `True`.
252
+ """
253
+ steps_per_epoch = len(dataset) // batch_size
254
+
255
+ if shuffle:
256
+ batch_idx = jax.random.permutation(rng, len(dataset))
257
+ else:
258
+ batch_idx = jnp.arange(len(dataset))
259
+
260
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
261
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
262
+
263
+ for idx in batch_idx:
264
+ batch = dataset[idx]
265
+ batch = {k: jnp.array(v) for k, v in batch.items()}
266
+
267
+ batch = shard(batch)
268
+
269
+ yield batch
270
+
271
+
272
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
273
+ summary_writer.scalar("train_time", train_time, step)
274
+
275
+ train_metrics = get_metrics(train_metrics)
276
+ for key, vals in train_metrics.items():
277
+ tag = f"train_{key}"
278
+ for i, val in enumerate(vals):
279
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
280
+
281
+ for metric_name, value in eval_metrics.items():
282
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
283
+
284
+
285
+ def create_learning_rate_fn(
286
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
287
+ ) -> Callable[[int], jnp.array]:
288
+ """Returns a linear warmup, linear_decay learning rate function."""
289
+ steps_per_epoch = train_ds_size // train_batch_size
290
+ num_train_steps = steps_per_epoch * num_train_epochs
291
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
292
+ decay_fn = optax.linear_schedule(
293
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
294
+ )
295
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
296
+ return schedule_fn
297
+
298
+
299
+ def main():
300
+ # See all possible arguments in src/transformers/training_args.py
301
+ # or by passing the --help flag to this script.
302
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
303
+
304
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
305
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
306
+ # If we pass only one argument to the script and it's the path to a json file,
307
+ # let's parse it to get our arguments.
308
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
309
+ else:
310
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
311
+
312
+ if (
313
+ os.path.exists(training_args.output_dir)
314
+ and os.listdir(training_args.output_dir)
315
+ and training_args.do_train
316
+ and not training_args.overwrite_output_dir
317
+ ):
318
+ raise ValueError(
319
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
320
+ "Use --overwrite_output_dir to overcome."
321
+ )
322
+
323
+ # Make one log on every process with the configuration for debugging.
324
+ logging.basicConfig(
325
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
326
+ datefmt="%m/%d/%Y %H:%M:%S",
327
+ level=logging.INFO,
328
+ )
329
+ # Setup logging, we only want one process per machine to log things on the screen.
330
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
331
+ if jax.process_index() == 0:
332
+ datasets.utils.logging.set_verbosity_warning()
333
+ transformers.utils.logging.set_verbosity_info()
334
+ else:
335
+ datasets.utils.logging.set_verbosity_error()
336
+ transformers.utils.logging.set_verbosity_error()
337
+
338
+ # Set the verbosity to info of the Transformers logger (on main process only):
339
+ logger.info(f"Training/evaluation parameters {training_args}")
340
+
341
+ # Handle the repository creation
342
+ if training_args.push_to_hub:
343
+ if training_args.hub_model_id is None:
344
+ repo_name = get_full_repo_name(
345
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
346
+ )
347
+ else:
348
+ repo_name = training_args.hub_model_id
349
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
350
+
351
+ # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
352
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
353
+ # (the dataset will be downloaded automatically from the datasets Hub).
354
+ #
355
+ # For CSV/JSON files this script will use the first column for the full texts and the second column for the
356
+ # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
357
+ #
358
+ if data_args.pretokenized:
359
+ train_files = [f"train{i}.parquet" for i in range(3)]
360
+ dataset = load_dataset("parquet", data_files={"train": train_files, "validation": "val.parquet"})
361
+ elif data_args.dataset_name is not None:
362
+ # Downloading and loading a dataset from the hub.
363
+ dataset = load_dataset(
364
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
365
+ )
366
+ else:
367
+ data_files = {}
368
+ if data_args.train_file is not None:
369
+ data_files["train"] = data_args.train_file
370
+ extension = data_args.train_file.split(".")[-1]
371
+ if data_args.validation_file is not None:
372
+ data_files["validation"] = data_args.validation_file
373
+ extension = data_args.validation_file.split(".")[-1]
374
+ if data_args.test_file is not None:
375
+ data_files["test"] = data_args.test_file
376
+ extension = data_args.test_file.split(".")[-1]
377
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
378
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
379
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
380
+
381
+ # Load pretrained model and tokenizer
382
+
383
+ if model_args.config_name:
384
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
385
+ elif model_args.model_name_or_path:
386
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
387
+ else:
388
+ config = CONFIG_MAPPING[model_args.model_type]()
389
+ logger.warning("You are instantiating a new config instance from scratch.")
390
+
391
+ if model_args.tokenizer_name:
392
+ tokenizer = AutoTokenizer.from_pretrained(
393
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
394
+ )
395
+ elif model_args.model_name_or_path:
396
+ tokenizer = AutoTokenizer.from_pretrained(
397
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
398
+ )
399
+ else:
400
+ raise ValueError(
401
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
402
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
403
+ )
404
+
405
+ if model_args.model_name_or_path:
406
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
407
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
408
+ )
409
+ else:
410
+ model = FlaxAutoModelForSeq2SeqLM.from_config(
411
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
412
+ )
413
+
414
+ if model.config.decoder_start_token_id is None:
415
+ raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
416
+
417
+ prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
418
+
419
+ # Preprocessing the datasets.
420
+ # We need to tokenize inputs and targets.
421
+ if data_args.pretokenized:
422
+ column_names = ["context", "question"]
423
+ if training_args.do_train:
424
+ column_names = dataset["train"].column_names
425
+ elif training_args.do_eval:
426
+ column_names = dataset["validation"].column_names
427
+ elif training_args.do_predict:
428
+ column_names = dataset["test"].column_names
429
+ else:
430
+ logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
431
+ return
432
+
433
+ # Get the column names for input/target.
434
+ dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
435
+ if data_args.text_column is None:
436
+ text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
437
+ else:
438
+ text_column = data_args.text_column
439
+ if text_column not in column_names:
440
+ raise ValueError(
441
+ f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
442
+ )
443
+ if data_args.summary_column is None:
444
+ summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
445
+ else:
446
+ summary_column = data_args.summary_column
447
+ if summary_column not in column_names:
448
+ raise ValueError(
449
+ f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
450
+ )
451
+
452
+ # Temporarily set max_target_length for training.
453
+ max_target_length = data_args.max_target_length
454
+
455
+ # In Flax, for seq2seq models we need to pass `decoder_input_ids`
456
+ # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
457
+ # for that dynamically import the `shift_tokens_right` function from the model file
458
+ # model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
459
+ # shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
460
+
461
+ # Setting padding="max_length" as we need fixed length inputs for jitted functions
462
+ def preprocess_function(examples):
463
+ inputs = examples[text_column]
464
+ targets = examples[summary_column]
465
+ inputs = [prefix + inp for inp in inputs]
466
+ model_inputs = tokenizer(
467
+ inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
468
+ )
469
+
470
+ # Setup the tokenizer for targets
471
+ with tokenizer.as_target_tokenizer():
472
+ labels = tokenizer(
473
+ targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
474
+ )
475
+
476
+ model_inputs["labels"] = labels["input_ids"]
477
+ decoder_input_ids = shift_tokens_right_fn(
478
+ jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
479
+ )
480
+ model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
481
+
482
+ # We need decoder_attention_mask so we can ignore pad tokens from loss
483
+ model_inputs["decoder_attention_mask"] = labels["attention_mask"]
484
+
485
+ return model_inputs
486
+
487
+ if data_args.pretokenized:
488
+ train_dataset = dataset["train"]
489
+ elif training_args.do_train:
490
+ if "train" not in dataset:
491
+ raise ValueError("--do_train requires a train dataset")
492
+ train_dataset = dataset["train"]
493
+ if data_args.max_train_samples is not None:
494
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
495
+ train_dataset = train_dataset.map(
496
+ preprocess_function,
497
+ batched=True,
498
+ num_proc=data_args.preprocessing_num_workers,
499
+ remove_columns=column_names,
500
+ load_from_cache_file=not data_args.overwrite_cache,
501
+ desc="Running tokenizer on train dataset",
502
+ )
503
+
504
+ if data_args.pretokenized:
505
+ eval_dataset = dataset["validation"]
506
+ elif training_args.do_eval:
507
+ max_target_length = data_args.val_max_target_length
508
+ if "validation" not in dataset:
509
+ raise ValueError("--do_eval requires a validation dataset")
510
+ eval_dataset = dataset["validation"]
511
+ if data_args.max_eval_samples is not None:
512
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
513
+ eval_dataset = eval_dataset.map(
514
+ preprocess_function,
515
+ batched=True,
516
+ num_proc=data_args.preprocessing_num_workers,
517
+ remove_columns=column_names,
518
+ load_from_cache_file=not data_args.overwrite_cache,
519
+ desc="Running tokenizer on validation dataset",
520
+ )
521
+
522
+ if training_args.do_predict:
523
+ max_target_length = data_args.val_max_target_length
524
+ if "test" not in dataset:
525
+ raise ValueError("--do_predict requires a test dataset")
526
+ predict_dataset = dataset["test"]
527
+ if data_args.max_predict_samples is not None:
528
+ predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
529
+ predict_dataset = predict_dataset.map(
530
+ preprocess_function,
531
+ batched=True,
532
+ num_proc=data_args.preprocessing_num_workers,
533
+ remove_columns=column_names,
534
+ load_from_cache_file=not data_args.overwrite_cache,
535
+ desc="Running tokenizer on prediction dataset",
536
+ )
537
+
538
+ # Metric
539
+ metric = load_metric("rouge")
540
+
541
+ def postprocess_text(preds, labels):
542
+ preds = [pred.strip() for pred in preds]
543
+ labels = [label.strip() for label in labels]
544
+
545
+ # rougeLSum expects newline after each sentence
546
+ preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
547
+ labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
548
+
549
+ return preds, labels
550
+
551
+ def compute_metrics(preds, labels):
552
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
553
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
554
+
555
+ # Some simple post-processing
556
+ decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
557
+
558
+ result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
559
+ # Extract a few results from ROUGE
560
+ result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
561
+
562
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
563
+ result["gen_len"] = np.mean(prediction_lens)
564
+ result = {k: round(v, 4) for k, v in result.items()}
565
+ return result
566
+
567
+ # Enable tensorboard only on the master node
568
+ has_tensorboard = is_tensorboard_available()
569
+ if has_tensorboard and jax.process_index() == 0:
570
+ try:
571
+ from flax.metrics.tensorboard import SummaryWriter
572
+
573
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
574
+ except ImportError as ie:
575
+ has_tensorboard = False
576
+ logger.warning(
577
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
578
+ )
579
+ else:
580
+ logger.warning(
581
+ "Unable to display metrics through TensorBoard because the package is not installed: "
582
+ "Please run pip install tensorboard to enable."
583
+ )
584
+
585
+ # Initialize our training
586
+ rng = jax.random.PRNGKey(training_args.seed)
587
+ rng, dropout_rng = jax.random.split(rng)
588
+
589
+ # Store some constant
590
+ num_epochs = int(training_args.num_train_epochs)
591
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
592
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
593
+ steps_per_epoch = len(train_dataset) // train_batch_size
594
+ total_train_steps = steps_per_epoch * num_epochs
595
+
596
+ # Create learning rate schedule
597
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
598
+ len(train_dataset),
599
+ train_batch_size,
600
+ training_args.num_train_epochs,
601
+ training_args.warmup_steps,
602
+ training_args.learning_rate,
603
+ )
604
+
605
+ # We use Optax's "masking" functionality to not apply weight decay
606
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
607
+ # mask boolean with the same structure as the parameters.
608
+ # The mask is True for parameters that should be decayed.
609
+ # Note that this mask is specifically adapted for FlaxBart.
610
+ # For FlaxT5, one should correct the layer norm parameter naming
611
+ # accordingly - see `run_t5_mlm_flax.py` e.g.
612
+ def decay_mask_fn(params):
613
+ flat_params = traverse_util.flatten_dict(params)
614
+ layer_norm_params = [
615
+ (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
616
+ ]
617
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
618
+ return traverse_util.unflatten_dict(flat_mask)
619
+
620
+ # create adam optimizer
621
+ adamw = optax.adamw(
622
+ learning_rate=linear_decay_lr_schedule_fn,
623
+ b1=training_args.adam_beta1,
624
+ b2=training_args.adam_beta2,
625
+ eps=training_args.adam_epsilon,
626
+ weight_decay=training_args.weight_decay,
627
+ mask=decay_mask_fn,
628
+ )
629
+
630
+ # Setup train state
631
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
632
+
633
+ # label smoothed cross entropy
634
+ def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
635
+ """
636
+ The label smoothing implementation is adapted from Flax's official example:
637
+ https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
638
+ """
639
+ vocab_size = logits.shape[-1]
640
+ confidence = 1.0 - label_smoothing_factor
641
+ low_confidence = (1.0 - confidence) / (vocab_size - 1)
642
+ normalizing_constant = -(
643
+ confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
644
+ )
645
+ soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
646
+
647
+ loss = optax.softmax_cross_entropy(logits, soft_labels)
648
+ loss = loss - normalizing_constant
649
+
650
+ # ignore padded tokens from loss
651
+ loss = loss * padding_mask
652
+ loss = loss.sum() / padding_mask.sum()
653
+ return loss
654
+
655
+ # Define gradient update step fn
656
+ def train_step(state, batch, label_smoothing_factor=0.0):
657
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
658
+
659
+ def compute_loss(params):
660
+ labels = batch.pop("labels")
661
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
662
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
663
+ return loss
664
+
665
+ grad_fn = jax.value_and_grad(compute_loss)
666
+ loss, grad = grad_fn(state.params)
667
+ grad = jax.lax.pmean(grad, "batch")
668
+
669
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
670
+
671
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
672
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
673
+
674
+ return new_state, metrics
675
+
676
+ # Define eval fn
677
+ def eval_step(params, batch, label_smoothing_factor=0.0):
678
+ labels = batch.pop("labels")
679
+ logits = model(**batch, params=params, train=False)[0]
680
+ loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
681
+
682
+ # summarize metrics
683
+ metrics = {"loss": loss}
684
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
685
+ return metrics
686
+
687
+ # Define generation function
688
+ max_length = (
689
+ data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
690
+ )
691
+ num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
692
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
693
+
694
+ def generate_step(params, batch):
695
+ model.params = params
696
+ output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
697
+ return output_ids.sequences
698
+
699
+ # Create parallel version of the train and eval step
700
+ p_train_step = jax.pmap(
701
+ partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
702
+ )
703
+ p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
704
+ p_generate_step = jax.pmap(generate_step, "batch")
705
+
706
+ # Replicate the train state on each device
707
+ state = state.replicate()
708
+
709
+ logger.info("***** Running training *****")
710
+ logger.info(f" Num examples = {len(train_dataset)}")
711
+ logger.info(f" Num Epochs = {num_epochs}")
712
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
713
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
714
+ logger.info(f" Total optimization steps = {total_train_steps}")
715
+
716
+ train_time = 0
717
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
718
+ for epoch in epochs:
719
+ # ======================== Training ================================
720
+ train_start = time.time()
721
+
722
+ # Create sampling rng
723
+ rng, input_rng = jax.random.split(rng)
724
+ train_metrics = []
725
+
726
+ # Generate an epoch by shuffling sampling indices from the train dataset
727
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
728
+ steps_per_epoch = len(train_dataset) // train_batch_size
729
+ # train
730
+ for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
731
+ batch = next(train_loader)
732
+ state, train_metric = p_train_step(state, batch)
733
+ train_metrics.append(train_metric)
734
+
735
+ train_time += time.time() - train_start
736
+
737
+ train_metric = unreplicate(train_metric)
738
+
739
+ epochs.write(
740
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
741
+ )
742
+
743
+ # ======================== Evaluating ==============================
744
+ eval_metrics = []
745
+ eval_preds = []
746
+ eval_labels = []
747
+
748
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
749
+ eval_steps = len(eval_dataset) // eval_batch_size
750
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
751
+ # Model forward
752
+ batch = next(eval_loader)
753
+ labels = batch["labels"]
754
+
755
+ metrics = p_eval_step(state.params, batch)
756
+ eval_metrics.append(metrics)
757
+
758
+ # generation
759
+ if data_args.predict_with_generate:
760
+ generated_ids = p_generate_step(state.params, batch)
761
+ eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
762
+ eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
763
+
764
+ # normalize eval metrics
765
+ eval_metrics = get_metrics(eval_metrics)
766
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
767
+
768
+ # compute ROUGE metrics
769
+ rouge_desc = ""
770
+ if data_args.predict_with_generate:
771
+ rouge_metrics = compute_metrics(eval_preds, eval_labels)
772
+ eval_metrics.update(rouge_metrics)
773
+ rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
774
+
775
+ # Print metrics and update progress bar
776
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
777
+ epochs.write(desc)
778
+ epochs.desc = desc
779
+
780
+ # Save metrics
781
+ if has_tensorboard and jax.process_index() == 0:
782
+ cur_step = epoch * (len(train_dataset) // train_batch_size)
783
+ write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
784
+
785
+ # save checkpoint after each epoch and push checkpoint to the hub
786
+ if jax.process_index() == 0:
787
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
788
+ model.save_pretrained(training_args.output_dir, params=params)
789
+ tokenizer.save_pretrained(training_args.output_dir)
790
+ if training_args.push_to_hub:
791
+ repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
792
+
793
+ # ======================== Prediction loop ==============================
794
+ if training_args.do_predict:
795
+ logger.info("*** Predict ***")
796
+
797
+ pred_metrics = []
798
+ pred_generations = []
799
+ pred_labels = []
800
+
801
+ pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
802
+ pred_steps = len(predict_dataset) // eval_batch_size
803
+ for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
804
+ # Model forward
805
+ batch = next(pred_loader)
806
+ labels = batch["labels"]
807
+
808
+ metrics = p_eval_step(state.params, batch)
809
+ pred_metrics.append(metrics)
810
+
811
+ # generation
812
+ if data_args.predict_with_generate:
813
+ generated_ids = p_generate_step(state.params, batch)
814
+ pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
815
+ pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
816
+
817
+ # normalize prediction metrics
818
+ pred_metrics = get_metrics(pred_metrics)
819
+ pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
820
+
821
+ # compute ROUGE metrics
822
+ rouge_desc = ""
823
+ if data_args.predict_with_generate:
824
+ rouge_metrics = compute_metrics(pred_generations, pred_labels)
825
+ pred_metrics.update(rouge_metrics)
826
+ rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
827
+
828
+ # Print metrics
829
+ desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
830
+ logger.info(desc)
831
+
832
+
833
+ if __name__ == "__main__":
834
+ main()