m3hrdadfi committed
Commit 4c28b8d
1 Parent(s): c2af344

Add training/preparation scripts

notes/.keep ADDED
File without changes
src/preparaing_recipe_nlg_dataset.py ADDED
@@ -0,0 +1,112 @@
+ import ast
+ import logging
+ import os
+ import sys
+ from dataclasses import dataclass, field
+
+ import pandas as pd
+ from tqdm import tqdm
+ from typing import Dict, List, Optional, Tuple
+
+ from datasets import load_dataset
+ from transformers import (
+     HfArgumentParser,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class DataArguments:
+     """
+     Arguments to which dataset we are going to set up.
+     """
+
+     output_dir: str = field(
+         default=".",
+         metadata={"help": "The output directory where the config will be written."},
+     )
+     dataset_name: str = field(
+         default=None,
+         metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_data_dir: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     cache_dir: Optional[str] = field(
+         default=None,
+         metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+     )
+
+
+ def main():
+     parser = HfArgumentParser([DataArguments])
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
+     else:
+         data_args = parser.parse_args_into_dataclasses()[0]
+
+     # Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+     logger.setLevel(logging.INFO)
+
+     logger.info(f"Preparing the dataset")
+
+     if data_args.dataset_name is not None:
+         dataset = load_dataset(
+             data_args.dataset_name,
+             data_dir=data_args.dataset_data_dir,
+             cache_dir=data_args.cache_dir
+         )
+     else:
+         dataset = load_dataset(
+             data_args.dataset_name,
+             cache_dir=data_args.cache_dir
+         )
+
+     def cleaning(text, item_type="ner"):
+         # NOTE: DO THE CLEANING LATER
+         return text
+
+     def recipe_preparation(item_dict):
+         requirements = ["ner", "ingredients", "steps"]
+         constraints = [3, 3, 10]
+         if not all([
+             True if requirements[i] in item_dict and len(item_dict[requirements[i]].split()) > constraints[i] else False
+             for i in range(len(requirements))
+         ]):
+             return None
+
+         ner = cleaning(item_dict["ner"], "ner")
+         ingredients = cleaning(item_dict["ingredients"], "ingredients")
+         steps = cleaning(item_dict["steps"], "steps")
+
+         return {
+             "inputs": ner,
+             "targets": f"{ingredients}<sep>{steps}"
+         }
+
+     for subset in dataset.keys():
+         data_dict = []
+         for item in tqdm(dataset[subset], position=0, total=len(dataset[subset])):
+             item = recipe_preparation(item)
+             if item:
+                 data_dict.append(item)
+
+         data_df = pd.DataFrame(data_dict)
+         logger.info(f"Preparation of [{subset}] set consists of {len(data_df)} records!")
+
+         output_path = os.path.join(data_args.output_dir, f"{subset}.csv")
+         os.makedirs(os.path.dirname(output_path), exist_ok=True)
+         data_df.to_csv(output_path, sep="\t", encoding="utf-8", index=False)
+         logger.info(f"Data saved here {output_path}")
+
+
+ if __name__ == '__main__':
+     main()
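
Note: the preparation script writes one tab-separated file per split to --output_dir, named {subset}.csv, with an `inputs` column (the NER entities) and a `targets` column in which ingredients and steps are joined by the `<sep>` token. As a quick sanity check of the produced files, a minimal sketch (the `data/` directory and the split names are assumptions, not part of the commit):

import pandas as pd

# Assumed output location and split names; adjust to whatever --output_dir produced.
for split in ["train", "dev", "test"]:
    df = pd.read_csv(f"data/{split}.csv", sep="\t", encoding="utf-8")
    print(split, len(df), list(df.columns))  # expected columns: inputs, targets
    ingredients, steps = df.iloc[0]["targets"].split("<sep>", 1)
    print(ingredients[:80], "|", steps[:80])

These `inputs`/`targets` columns and the tab delimiter are what run.sh and run_ed_recipe_nlg.py below expect (`--text_column=inputs`, `--target_column=targets`, `delimiter="\t"`).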
src/run.sh ADDED
@@ -0,0 +1,51 @@
+ #!/bin/bash
+
+ export LC_ALL=C.UTF-8
+ export LANG=C.UTF-8
+
+ export OUTPUT_DIR=/to/our/path
+ export MODEL_NAME_OR_PATH=t5-base
+ export NUM_BEAMS=3
+
+ export TRAIN_FILE=/to/../train.csv
+ export VALIDATION_FILE=/to/../dev.csv
+ export TEST_FILE=/to/../test.csv
+ export TEXT_COLUMN=inputs
+ export TARGET_COLUMN=targets
+ export MAX_SOURCE_LENGTH=128
+ export MAX_TARGET_LENGTH=1024
+ export SOURCE_PREFIX=ingredients
+
+ export PER_DEVICE_TRAIN_BATCH_SIZE=8
+ export PER_DEVICE_EVAL_BATCH_SIZE=8
+ export GRADIENT_ACCUMULATION_STEPS=2
+ export NUM_TRAIN_EPOCHS=3.0
+ export LEARNING_RATE=5e-4
+ export WARMUP_STEPS=5000
+
+ python run_ed_recipe_nlg.py \
+     --output_dir="$OUTPUT_DIR" \
+     --train_file="$TRAIN_FILE" \
+     --validation_file="$VALIDATION_FILE" \
+     --test_file="$TEST_FILE" \
+     --text_column="$TEXT_COLUMN" \
+     --target_column="$TARGET_COLUMN" \
+     --source_prefix="$SOURCE_PREFIX: " \
+     --max_source_length="$MAX_SOURCE_LENGTH" \
+     --max_target_length="$MAX_TARGET_LENGTH" \
+     --model_name_or_path="$MODEL_NAME_OR_PATH" \
+     --extra_tokens="" \
+     --special_tokens="<sep>,<items>" \
+     --per_device_train_batch_size=$PER_DEVICE_TRAIN_BATCH_SIZE \
+     --per_device_eval_batch_size=$PER_DEVICE_EVAL_BATCH_SIZE \
+     --gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS \
+     --num_train_epochs=$NUM_TRAIN_EPOCHS \
+     --learning_rate=$LEARNING_RATE \
+     --warmup_steps=$WARMUP_STEPS \
+     --preprocessing_num_workers=4 \
+     --prediction_debug \
+     --do_train \
+     --do_eval \
+     --do_predict \
+     --overwrite_output_dir \
+     --predict_with_generate
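
Note how these variables land in the training script: `--source_prefix="$SOURCE_PREFIX: "` expands to the literal prefix `ingredients: `, which `preprocess_function` prepends to every input, and `--special_tokens="<sep>,<items>"` is split on `,` and added to the tokenizer before the embeddings are resized with `model.resize_token_embeddings(len(tokenizer))`. A minimal sketch of that token handling in isolation (the sample text is illustrative only):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")  # MODEL_NAME_OR_PATH above
added = tokenizer.add_tokens("<sep>,<items>".split(","), special_tokens=True)  # --special_tokens
text = "ingredients: " + "chicken, salt, pepper"  # --source_prefix plus an illustrative input
print(added, len(tokenizer), tokenizer.tokenize(text)[:6])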
src/run_ed_recipe_nlg.py ADDED
@@ -0,0 +1,857 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Fine-tuning the library models for summarization.
+ """
+ # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
+
+ import logging
+ import os
+ import random
+ import sys
+ import time
+ from dataclasses import dataclass, field
+ from functools import partial
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ import datasets
+ import nltk  # Here to have a nice missing dependency error message early on
+ import numpy as np
+ from datasets import Dataset, load_dataset, load_metric
+ from tqdm import tqdm
+
+ import jax
+ import jax.numpy as jnp
+ import optax
+ import transformers
+ from filelock import FileLock
+ from flax import jax_utils, traverse_util
+ from flax.jax_utils import unreplicate
+ from flax.training import train_state
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+ from transformers import (
+     CONFIG_MAPPING,
+     FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+     AutoConfig,
+     AutoTokenizer,
+     FlaxAutoModelForSeq2SeqLM,
+     HfArgumentParser,
+     TrainingArguments,
+     is_tensorboard_available,
+ )
+ from transformers.file_utils import is_offline_mode
+
+ logger = logging.getLogger(__name__)
+
+ try:
+     nltk.data.find("tokenizers/punkt")
+ except (LookupError, OSError):
+     if is_offline_mode():
+         raise LookupError(
+             "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+         )
+     with FileLock(".lock") as lock:
+         nltk.download("punkt", quiet=True)
+
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+     """
+
+     model_name_or_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The model checkpoint for weights initialization."
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     model_type: Optional[str] = field(
+         default=None,
+         metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+     )
+     dtype: Optional[str] = field(
+         default="float32",
+         metadata={
+             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+         },
+     )
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     dataset_name: Optional[str] = field(
+         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     text_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "The name of the column in the datasets containing the inputs (for generation)."},
+     )
+     target_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "The name of the column in the datasets containing the targets (for generation)."},
+     )
+     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+     validation_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+     )
+     max_source_length: Optional[int] = field(
+         default=128,
+         metadata={
+             "help": "The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded."
+         },
+     )
+     max_target_length: Optional[int] = field(
+         default=1024,
+         metadata={
+             "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded."
+         },
+     )
+     val_max_target_length: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
+             "This argument is also used to override the `max_length` param of `model.generate`, which is used "
+             "during evaluation."
+         },
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set."
+         },
+     )
+     max_eval_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+             "value if set."
+         },
+     )
+     max_predict_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+             "value if set."
+         },
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+     source_prefix: Optional[str] = field(
+         default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+     )
+     predict_with_generate: bool = field(
+         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
+     )
+     num_beams: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
+             "which is used during evaluation."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+     extra_tokens: str = field(
+         default=None,
+         metadata={"help": "A text list of extra tokens separated by `,` that you want to add to the vocab."},
+     )
+     special_tokens: str = field(
+         default=None,
+         metadata={"help": "A list of special tokens separated by `,` that you want to add to the vocab."},
+     )
+     prediction_debug: bool = field(
+         default=False,
+         metadata={
+             "help": "Whether to show some examples of the model prediction"
+         },
+     )
+
+     def __post_init__(self):
+         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+             raise ValueError("Need either a dataset name or a training/validation file.")
+         else:
+             if self.train_file is not None:
+                 extension = self.train_file.split(".")[-1]
+                 assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+             if self.validation_file is not None:
+                 extension = self.validation_file.split(".")[-1]
+                 assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+         if self.val_max_target_length is None:
+             self.val_max_target_length = self.max_target_length
+
+
+ class TrainState(train_state.TrainState):
+     dropout_rng: jnp.ndarray
+
+     def replicate(self):
+         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
+
+
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+     """
+     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
+     Shuffle batches if `shuffle` is `True`.
+     """
+     steps_per_epoch = len(dataset) // batch_size
+
+     if shuffle:
+         batch_idx = jax.random.permutation(rng, len(dataset))
+     else:
+         batch_idx = jnp.arange(len(dataset))
+
+     batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+
+     for idx in batch_idx:
+         batch = dataset[idx]
+         batch = {k: jnp.array(v) for k, v in batch.items()}
+
+         batch = shard(batch)
+
+         yield batch
+
+
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+     summary_writer.scalar("train_time", train_time, step)
+
+     train_metrics = get_metrics(train_metrics)
+     for key, vals in train_metrics.items():
+         tag = f"train_{key}"
+         for i, val in enumerate(vals):
+             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+     for metric_name, value in eval_metrics.items():
+         summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+ def create_learning_rate_fn(
+     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+ ) -> Callable[[int], jnp.array]:
+     """Returns a linear warmup, linear_decay learning rate function."""
+     steps_per_epoch = train_ds_size // train_batch_size
+     num_train_steps = steps_per_epoch * num_train_epochs
+     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+     decay_fn = optax.linear_schedule(
+         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+     )
+     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+     return schedule_fn
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     if (
+         os.path.exists(training_args.output_dir)
+         and os.listdir(training_args.output_dir)
+         and training_args.do_train
+         and not training_args.overwrite_output_dir
+     ):
+         raise ValueError(
+             f"Output directory ({training_args.output_dir}) already exists and is not empty."
+             "Use --overwrite_output_dir to overcome."
+         )
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     # Setup logging, we only want one process per machine to log things on the screen.
+     logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+     if jax.process_index() == 0:
+         datasets.utils.logging.set_verbosity_warning()
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         datasets.utils.logging.set_verbosity_error()
+         transformers.utils.logging.set_verbosity_error()
+
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
+     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+     # (the dataset will be downloaded automatically from the datasets Hub).
+     #
+     # For CSV/JSON files this script will use the first column for the full texts and the second column for the
+     # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
+     #
+     if data_args.dataset_name is not None:
+         # Downloading and loading a dataset from the hub.
+         dataset = load_dataset(
+             data_args.dataset_name,
+             data_args.dataset_config_name,
+             cache_dir=model_args.cache_dir,
+             keep_in_memory=False
+         )
+     else:
+         data_files = {}
+
+         if data_args.train_file is not None:
+             data_files["train"] = data_args.train_file
+             extension = data_args.train_file.split(".")[-1]
+         if data_args.validation_file is not None:
+             data_files["validation"] = data_args.validation_file
+             extension = data_args.validation_file.split(".")[-1]
+         if data_args.test_file is not None:
+             data_files["test"] = data_args.test_file
+             extension = data_args.test_file.split(".")[-1]
+
+         dataset = load_dataset(
+             extension,
+             data_files=data_files,
+             delimiter="\t",
+             cache_dir=model_args.cache_dir
+         )
+     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+     # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+     # Load pretrained model and tokenizer
+
+     if model_args.config_name:
+         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+     elif model_args.model_name_or_path:
+         config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+     else:
+         config = CONFIG_MAPPING[model_args.model_type]()
+         logger.warning("You are instantiating a new config instance from scratch.")
+
+     if model_args.tokenizer_name:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     elif model_args.model_name_or_path:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     else:
+         raise ValueError(
+             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+         )
+
+     if data_args.extra_tokens and isinstance(data_args.extra_tokens, str):
+         extra_tokens = list(data_args.extra_tokens.split(","))
+         if len(extra_tokens) > 0:
+             logger.info(f"*** Adding extra tokens: {extra_tokens} ***")
+             tokenizer.add_tokens(extra_tokens, special_tokens=False)
+
+     if data_args.special_tokens and isinstance(data_args.special_tokens, str):
+         special_tokens = list(data_args.special_tokens.split(","))
+         if len(special_tokens) > 0:
+             logger.info(f"*** Adding special tokens: {special_tokens} ***")
+             tokenizer.add_tokens(special_tokens, special_tokens=True)
+
+     if model_args.model_name_or_path:
+         model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+             model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+         )
+     else:
+         model = FlaxAutoModelForSeq2SeqLM.from_config(
+             config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+         )
+
+     model.resize_token_embeddings(len(tokenizer))
+     if model.config.decoder_start_token_id is None:
+         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
+
+     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
+
+     # Preprocessing the datasets.
+     # We need to tokenize inputs and targets.
+     if training_args.do_train:
+         column_names = dataset["train"].column_names
+     elif training_args.do_eval:
+         column_names = dataset["validation"].column_names
+     elif training_args.do_predict:
+         column_names = dataset["test"].column_names
+     else:
+         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
+         return
+
+     # Get the column names for input/target.
+     if data_args.text_column is None:
+         text_column = column_names[0]
+     else:
+         text_column = data_args.text_column
+         if text_column not in column_names:
+             raise ValueError(
+                 f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
+             )
+     if data_args.target_column is None:
+         target_column = column_names[1]
+     else:
+         target_column = data_args.target_column
+         if target_column not in column_names:
+             raise ValueError(
+                 f"--target_column' value '{data_args.target_column}' needs to be one of: {', '.join(column_names)}"
+             )
+
+     # Temporarily set max_target_length for training.
+     max_target_length = data_args.max_target_length
+
+     # In Flax, for seq2seq models we need to pass `decoder_input_ids`
+     # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
+     # for that dynamically import the `shift_tokens_right` function from the model file
+     model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
+     shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
+
+     # Setting padding="max_length" as we need fixed length inputs for jitted functions
+     def preprocess_function(examples):
+         inputs = examples[text_column]
+         targets = examples[target_column]
+         inputs = [prefix + inp for inp in inputs]
+         model_inputs = tokenizer(
+             inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
+         )
+
+         # Setup the tokenizer for targets
+         with tokenizer.as_target_tokenizer():
+             labels = tokenizer(
+                 targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
+             )
+
+         model_inputs["labels"] = labels["input_ids"]
+         decoder_input_ids = shift_tokens_right_fn(
+             jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
+         )
+         model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
+
+         # We need decoder_attention_mask so we can ignore pad tokens from loss
+         model_inputs["decoder_attention_mask"] = labels["attention_mask"]
+
+         return model_inputs
+
+     if training_args.do_train:
+         if "train" not in dataset:
+             raise ValueError("--do_train requires a train dataset")
+         train_dataset = dataset["train"]
+         if data_args.max_train_samples is not None:
+             train_dataset = train_dataset.select(range(data_args.max_train_samples))
+         train_dataset = train_dataset.map(
+             preprocess_function,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+             desc="Running tokenizer on train dataset",
+         )
+
+     if training_args.do_eval:
+         max_target_length = data_args.val_max_target_length
+         if "validation" not in dataset:
+             raise ValueError("--do_eval requires a validation dataset")
+         eval_dataset = dataset["validation"]
+         if data_args.max_eval_samples is not None:
+             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+         eval_dataset = eval_dataset.map(
+             preprocess_function,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+             desc="Running tokenizer on validation dataset",
+         )
+
+     if training_args.do_predict:
+         max_target_length = data_args.val_max_target_length
+         if "test" not in dataset:
+             raise ValueError("--do_predict requires a test dataset")
+         predict_dataset = dataset["test"]
+         if data_args.max_predict_samples is not None:
+             predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+         predict_dataset = predict_dataset.map(
+             preprocess_function,
+             batched=True,
+             num_proc=data_args.preprocessing_num_workers,
+             remove_columns=column_names,
+             load_from_cache_file=not data_args.overwrite_cache,
+             desc="Running tokenizer on prediction dataset",
+         )
+
+     # Metrics
+     bleu = load_metric("sacrebleu")
+     wer = load_metric("wer")
+
+     def postprocess_text(preds, labels):
+         preds = [pred.strip() for pred in preds]
+         labels_bleu = [[label.strip()] for label in labels]
+         labels_wer = [label.strip() for label in labels]
+
+         return preds, [labels_bleu, labels_wer]
+
+     def compute_metrics(preds, labels):
+         decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
+         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+         # Some simple post-processing
+         decoded_preds, [decoded_labels_bleu, decoded_labels_wer] = postprocess_text(decoded_preds, decoded_labels)
+
+         if data_args.prediction_debug:
+             for index in random.sample(range(len(decoded_labels)), 3):
+                 logger.info(f'reference: "{decoded_labels[index]}"')
+                 logger.info(f'predicted: "{decoded_preds[index]}"')
+                 logger.info('---')
+
+         result = {}
+
+         try:
+             result_blue = bleu.compute(predictions=decoded_preds, references=decoded_labels_wer)
+             result_blue = result_blue["score"]
+         except Exception as e:
+             logger.info(f'Error occurred during bleu {e}')
+             result_blue = 0.0 * 100
+
+         try:
+             result_wer = wer.compute(predictions=decoded_preds, references=decoded_labels_wer)
+             result_wer = result_wer * 100
+         except Exception as e:
+             logger.info(f'Error occurred during wer {e}')
+             result_wer = 1.0 * 100
+
+         result["blue"] = result_blue
+         result["wer"] = result_wer
+
+         prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
+         result["gen_len"] = np.mean(prediction_lens)
+         result = {k: round(v, 4) for k, v in result.items()}
+         return result
+
+     # Enable tensorboard only on the master node
+     has_tensorboard = is_tensorboard_available()
+     if has_tensorboard and jax.process_index() == 0:
+         try:
+             from flax.metrics.tensorboard import SummaryWriter
+
+             summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+         except ImportError as ie:
+             has_tensorboard = False
+             logger.warning(
+                 f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+             )
+     else:
+         logger.warning(
+             "Unable to display metrics through TensorBoard because the package is not installed: "
+             "Please run pip install tensorboard to enable."
+         )
+
+     # Initialize our training
+     rng = jax.random.PRNGKey(training_args.seed)
+     rng, dropout_rng = jax.random.split(rng)
+
+     # Store some constant
+     num_epochs = int(training_args.num_train_epochs)
+     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+     steps_per_epoch = len(train_dataset) // train_batch_size
+     total_train_steps = steps_per_epoch * num_epochs
+
+     # Create learning rate schedule
+     linear_decay_lr_schedule_fn = create_learning_rate_fn(
+         len(train_dataset),
+         train_batch_size,
+         training_args.num_train_epochs,
+         training_args.warmup_steps,
+         training_args.learning_rate,
+     )
+
+     # We use Optax's "masking" functionality to not apply weight decay
+     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
+     # mask boolean with the same structure as the parameters.
+     # The mask is True for parameters that should be decayed.
+     # Note that this mask is specifically adapted for FlaxBart.
+     # For FlaxT5, one should correct the layer norm parameter naming
+     # accordingly - see `run_t5_mlm_flax.py` e.g.
+
+     if any(x in model_args.model_name_or_path for x in ["t5", "mt5", "byt5"]):
+         def decay_mask_fn(params):
+             flat_params = traverse_util.flatten_dict(params)
+             layer_norm_params = [
+                 (name, "scale") for name in ["layer_norm", "final_layer_norm"]
+             ]
+             flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+             return traverse_util.unflatten_dict(flat_mask)
+     else:
+         def decay_mask_fn(params):
+             flat_params = traverse_util.flatten_dict(params)
+             layer_norm_params = [
+                 (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
+             ]
+             flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
+             return traverse_util.unflatten_dict(flat_mask)
+
+     # create adam optimizer
+     adamw = optax.adamw(
+         learning_rate=linear_decay_lr_schedule_fn,
+         b1=training_args.adam_beta1,
+         b2=training_args.adam_beta2,
+         eps=training_args.adam_epsilon,
+         weight_decay=training_args.weight_decay,
+         mask=decay_mask_fn,
+     )
+
+     # Setup train state
+     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+
+     # label smoothed cross entropy
+     def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
+         """
+         The label smoothing implementation is adapted from Flax's official example:
+         https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
+         """
+         vocab_size = logits.shape[-1]
+         confidence = 1.0 - label_smoothing_factor
+         low_confidence = (1.0 - confidence) / (vocab_size - 1)
+         normalizing_constant = -(
+             confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
+         )
+         soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
+
+         loss = optax.softmax_cross_entropy(logits, soft_labels)
+         loss = loss - normalizing_constant
+
+         # ignore padded tokens from loss
+         loss = loss * padding_mask
+         loss = loss.sum() / padding_mask.sum()
+         return loss
+
+     # Define gradient update step fn
+     def train_step(state, batch, label_smoothing_factor=0.0):
+         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+
+         def compute_loss(params):
+             labels = batch.pop("labels")
+             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+             loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+             return loss
+
+         grad_fn = jax.value_and_grad(compute_loss)
+         loss, grad = grad_fn(state.params)
+         grad = jax.lax.pmean(grad, "batch")
+
+         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
+
+         metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+         return new_state, metrics
+
+     # Define eval fn
+     def eval_step(params, batch, label_smoothing_factor=0.0):
+         labels = batch.pop("labels")
+         logits = model(**batch, params=params, train=False)[0]
+         loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
+
+         # summarize metrics
+         metrics = {"loss": loss}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+         return metrics
+
+     # Define generation function
+     max_length = (
+         data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
+     )
+     num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
+     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+     def generate_step(params, batch):
+         model.params = params
+         output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
+         return output_ids.sequences
+
+     # Create parallel version of the train and eval step
+     p_train_step = jax.pmap(
+         partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
+     )
+     p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
+     p_generate_step = jax.pmap(generate_step, "batch")
+
+     # Replicate the train state on each device
+     state = state.replicate()
+
+     logger.info("***** Running training *****")
+     logger.info(f" Num examples = {len(train_dataset)}")
+     logger.info(f" Num Epochs = {num_epochs}")
+     logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
+     logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
+     logger.info(f" Total optimization steps = {total_train_steps}")
+
+     train_time = 0
+     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+     for epoch in epochs:
+         # ======================== Training ================================
+         train_start = time.time()
+
+         # Create sampling rng
+         rng, input_rng = jax.random.split(rng)
+         train_metrics = []
+
+         # Generate an epoch by shuffling sampling indices from the train dataset
+         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
+         steps_per_epoch = len(train_dataset) // train_batch_size
+         # train
+         for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+             batch = next(train_loader)
+             state, train_metric = p_train_step(state, batch)
+             train_metrics.append(train_metric)
+
+         train_time += time.time() - train_start
+
+         train_metric = unreplicate(train_metric)
+
+         epochs.write(
+             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+         )
+
+         # ======================== Evaluating ==============================
+         eval_metrics = []
+         eval_preds = []
+         eval_labels = []
+
+         eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+         eval_steps = len(eval_dataset) // eval_batch_size
+         for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
+             # Model forward
+             batch = next(eval_loader)
+             labels = batch["labels"]
+
+             metrics = p_eval_step(state.params, batch)
+             eval_metrics.append(metrics)
+
+             # generation
+             if data_args.predict_with_generate:
+                 generated_ids = p_generate_step(state.params, batch)
+                 eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                 eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+
+         # normalize eval metrics
+         eval_metrics = get_metrics(eval_metrics)
+         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
+         # compute ROUGE metrics
+         rouge_desc = ""
+         if data_args.predict_with_generate:
+             rouge_metrics = compute_metrics(eval_preds, eval_labels)
+             eval_metrics.update(rouge_metrics)
+             rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
+
+         # Print metrics and update progress bar
+         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
+         epochs.write(desc)
+         epochs.desc = desc
+
+         # Save metrics
+         if has_tensorboard and jax.process_index() == 0:
+             cur_step = epoch * (len(train_dataset) // train_batch_size)
+             write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+
+         # ======================== Prediction loop ==============================
+         if training_args.do_predict:
+             logger.info("*** Predict ***")
+
+             pred_metrics = []
+             pred_generations = []
+             pred_labels = []
+
+             pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
+             pred_steps = len(predict_dataset) // eval_batch_size
+             for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
+                 # Model forward
+                 batch = next(pred_loader)
+                 labels = batch["labels"]
+
+                 metrics = p_eval_step(state.params, batch)
+                 pred_metrics.append(metrics)
+
+                 # generation
+                 if data_args.predict_with_generate:
+                     generated_ids = p_generate_step(state.params, batch)
+                     pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
+                     pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
+
+             # normalize prediction metrics
+             pred_metrics = get_metrics(pred_metrics)
+             pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
+
+             # compute ROUGE metrics
+             rouge_desc = ""
+             if data_args.predict_with_generate:
+                 rouge_metrics = compute_metrics(pred_generations, pred_labels)
+                 pred_metrics.update(rouge_metrics)
+                 rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
+
+             # Print metrics
+             desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
+             logger.info(desc)
+
+         # save checkpoint after each epoch and push checkpoint to the hub
+         if jax.process_index() == 0:
+             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+             model.save_pretrained(
+                 training_args.output_dir,
+                 params=params,
+                 push_to_hub=training_args.push_to_hub,
+                 commit_message=f"Saving weights and logs of epoch {epoch + 1}",
+             )
+
+
+ if __name__ == "__main__":
+     main()
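
One wiring detail worth flagging: `main()` reads `data_args.test_file` (and run.sh passes `--test_file`), but `DataTrainingArguments` in this commit does not declare such a field, so `HfArgumentParser` would not accept the flag. A hypothetical field that would close the gap, mirroring `validation_file` (it is not part of the commit):

# Hypothetical addition inside DataTrainingArguments, mirroring `validation_file`;
# not part of this commit.
test_file: Optional[str] = field(
    default=None,
    metadata={"help": "An optional input test data file to evaluate the metrics on (a text file)."},
)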