Hooman Sedghamiz
commited on
Commit
•
01ae861
1
Parent(s):
3a36a39
pushing a template clm training script for gpt2
Browse files- src/run_clm_flax.py +625 -0
src/run_clm_flax.py
ADDED
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2021 The HuggingFace Team All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
|
18 |
+
|
19 |
+
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
20 |
+
https://huggingface.co/models?filter=causal-lm
|
21 |
+
"""
|
22 |
+
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import math
|
26 |
+
import os
|
27 |
+
import sys
|
28 |
+
import time
|
29 |
+
from dataclasses import dataclass, field
|
30 |
+
from pathlib import Path
|
31 |
+
from typing import Callable, Optional
|
32 |
+
|
33 |
+
import datasets
|
34 |
+
from datasets import Dataset, load_dataset
|
35 |
+
from tqdm import tqdm
|
36 |
+
|
37 |
+
import jax
|
38 |
+
import jax.numpy as jnp
|
39 |
+
import optax
|
40 |
+
import transformers
|
41 |
+
from flax import jax_utils, traverse_util
|
42 |
+
from flax.jax_utils import unreplicate
|
43 |
+
from flax.training import train_state
|
44 |
+
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
45 |
+
from transformers import (
|
46 |
+
CONFIG_MAPPING,
|
47 |
+
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
48 |
+
AutoConfig,
|
49 |
+
AutoTokenizer,
|
50 |
+
FlaxAutoModelForCausalLM,
|
51 |
+
HfArgumentParser,
|
52 |
+
TrainingArguments,
|
53 |
+
is_tensorboard_available,
|
54 |
+
)
|
55 |
+
from transformers.testing_utils import CaptureLogger
|
56 |
+
|
57 |
+
|
58 |
+
logger = logging.getLogger(__name__)
|
59 |
+
|
60 |
+
# Cache the result
|
61 |
+
has_tensorboard = is_tensorboard_available()
|
62 |
+
if has_tensorboard:
|
63 |
+
try:
|
64 |
+
from flax.metrics.tensorboard import SummaryWriter
|
65 |
+
except ImportError as ie:
|
66 |
+
has_tensorboard = False
|
67 |
+
print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
|
68 |
+
|
69 |
+
else:
|
70 |
+
print(
|
71 |
+
"Unable to display metrics through TensorBoard because the package is not installed: "
|
72 |
+
"Please run pip install tensorboard to enable."
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
77 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
78 |
+
|
79 |
+
|
80 |
+
@dataclass
|
81 |
+
class ModelArguments:
|
82 |
+
"""
|
83 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
84 |
+
"""
|
85 |
+
|
86 |
+
model_name_or_path: Optional[str] = field(
|
87 |
+
default=None,
|
88 |
+
metadata={
|
89 |
+
"help": "The model checkpoint for weights initialization."
|
90 |
+
"Don't set if you want to train a model from scratch."
|
91 |
+
},
|
92 |
+
)
|
93 |
+
model_type: Optional[str] = field(
|
94 |
+
default=None,
|
95 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
96 |
+
)
|
97 |
+
config_name: Optional[str] = field(
|
98 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
99 |
+
)
|
100 |
+
tokenizer_name: Optional[str] = field(
|
101 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
102 |
+
)
|
103 |
+
cache_dir: Optional[str] = field(
|
104 |
+
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
105 |
+
)
|
106 |
+
use_fast_tokenizer: bool = field(
|
107 |
+
default=True,
|
108 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
109 |
+
)
|
110 |
+
dtype: Optional[str] = field(
|
111 |
+
default="float32",
|
112 |
+
metadata={
|
113 |
+
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
114 |
+
},
|
115 |
+
)
|
116 |
+
|
117 |
+
|
118 |
+
@dataclass
|
119 |
+
class DataTrainingArguments:
|
120 |
+
"""
|
121 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
122 |
+
"""
|
123 |
+
|
124 |
+
dataset_name: Optional[str] = field(
|
125 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
126 |
+
)
|
127 |
+
dataset_config_name: Optional[str] = field(
|
128 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
129 |
+
)
|
130 |
+
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
131 |
+
validation_file: Optional[str] = field(
|
132 |
+
default=None,
|
133 |
+
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
134 |
+
)
|
135 |
+
max_train_samples: Optional[int] = field(
|
136 |
+
default=None,
|
137 |
+
metadata={
|
138 |
+
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
139 |
+
"value if set."
|
140 |
+
},
|
141 |
+
)
|
142 |
+
max_eval_samples: Optional[int] = field(
|
143 |
+
default=None,
|
144 |
+
metadata={
|
145 |
+
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
146 |
+
"value if set."
|
147 |
+
},
|
148 |
+
)
|
149 |
+
overwrite_cache: bool = field(
|
150 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
151 |
+
)
|
152 |
+
validation_split_percentage: Optional[int] = field(
|
153 |
+
default=5,
|
154 |
+
metadata={
|
155 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
156 |
+
},
|
157 |
+
)
|
158 |
+
block_size: Optional[int] = field(
|
159 |
+
default=None,
|
160 |
+
metadata={
|
161 |
+
"help": "Optional input sequence length after tokenization. "
|
162 |
+
"The training dataset will be truncated in block of this size for training. "
|
163 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
164 |
+
},
|
165 |
+
)
|
166 |
+
overwrite_cache: bool = field(
|
167 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
168 |
+
)
|
169 |
+
preprocessing_num_workers: Optional[int] = field(
|
170 |
+
default=None,
|
171 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
172 |
+
)
|
173 |
+
|
174 |
+
def __post_init__(self):
|
175 |
+
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
176 |
+
raise ValueError("Need either a dataset name or a training/validation file.")
|
177 |
+
else:
|
178 |
+
if self.train_file is not None:
|
179 |
+
extension = self.train_file.split(".")[-1]
|
180 |
+
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
181 |
+
if self.validation_file is not None:
|
182 |
+
extension = self.validation_file.split(".")[-1]
|
183 |
+
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
184 |
+
|
185 |
+
|
186 |
+
class TrainState(train_state.TrainState):
|
187 |
+
dropout_rng: jnp.ndarray
|
188 |
+
|
189 |
+
def replicate(self):
|
190 |
+
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
191 |
+
|
192 |
+
|
193 |
+
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
|
194 |
+
"""
|
195 |
+
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
196 |
+
Shuffle batches if `shuffle` is `True`.
|
197 |
+
"""
|
198 |
+
steps_per_epoch = len(dataset) // batch_size
|
199 |
+
|
200 |
+
if shuffle:
|
201 |
+
batch_idx = jax.random.permutation(rng, len(dataset))
|
202 |
+
else:
|
203 |
+
batch_idx = jnp.arange(len(dataset))
|
204 |
+
|
205 |
+
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
206 |
+
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
207 |
+
|
208 |
+
for idx in batch_idx:
|
209 |
+
batch = dataset[idx]
|
210 |
+
batch = {k: jnp.array(v) for k, v in batch.items()}
|
211 |
+
|
212 |
+
batch = shard(batch)
|
213 |
+
|
214 |
+
yield batch
|
215 |
+
|
216 |
+
|
217 |
+
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
218 |
+
summary_writer.scalar("train_time", train_time, step)
|
219 |
+
|
220 |
+
train_metrics = get_metrics(train_metrics)
|
221 |
+
for key, vals in train_metrics.items():
|
222 |
+
tag = f"train_{key}"
|
223 |
+
for i, val in enumerate(vals):
|
224 |
+
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
225 |
+
|
226 |
+
for metric_name, value in eval_metrics.items():
|
227 |
+
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
228 |
+
|
229 |
+
|
230 |
+
def create_learning_rate_fn(
|
231 |
+
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
232 |
+
) -> Callable[[int], jnp.array]:
|
233 |
+
"""Returns a linear warmup, linear_decay learning rate function."""
|
234 |
+
steps_per_epoch = train_ds_size // train_batch_size
|
235 |
+
num_train_steps = steps_per_epoch * num_train_epochs
|
236 |
+
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
237 |
+
decay_fn = optax.linear_schedule(
|
238 |
+
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
|
239 |
+
)
|
240 |
+
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
241 |
+
return schedule_fn
|
242 |
+
|
243 |
+
|
244 |
+
def main():
|
245 |
+
# See all possible arguments in src/transformers/training_args.py
|
246 |
+
# or by passing the --help flag to this script.
|
247 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
248 |
+
|
249 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
250 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
251 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
252 |
+
# let's parse it to get our arguments.
|
253 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
254 |
+
else:
|
255 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
256 |
+
|
257 |
+
if (
|
258 |
+
os.path.exists(training_args.output_dir)
|
259 |
+
and os.listdir(training_args.output_dir)
|
260 |
+
and training_args.do_train
|
261 |
+
and not training_args.overwrite_output_dir
|
262 |
+
):
|
263 |
+
raise ValueError(
|
264 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
265 |
+
"Use --overwrite_output_dir to overcome."
|
266 |
+
)
|
267 |
+
|
268 |
+
# Make one log on every process with the configuration for debugging.
|
269 |
+
logging.basicConfig(
|
270 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
271 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
272 |
+
level=logging.INFO,
|
273 |
+
)
|
274 |
+
# Setup logging, we only want one process per machine to log things on the screen.
|
275 |
+
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
276 |
+
if jax.process_index() == 0:
|
277 |
+
datasets.utils.logging.set_verbosity_warning()
|
278 |
+
transformers.utils.logging.set_verbosity_info()
|
279 |
+
else:
|
280 |
+
datasets.utils.logging.set_verbosity_error()
|
281 |
+
transformers.utils.logging.set_verbosity_error()
|
282 |
+
|
283 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
284 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
285 |
+
|
286 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
287 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
288 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
289 |
+
#
|
290 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
291 |
+
# 'text' is found. You can easily tweak this behavior (see below).
|
292 |
+
#
|
293 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
294 |
+
# download the dataset.
|
295 |
+
if data_args.dataset_name is not None:
|
296 |
+
# Downloading and loading a dataset from the hub.
|
297 |
+
dataset = load_dataset(
|
298 |
+
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
|
299 |
+
)
|
300 |
+
|
301 |
+
if "validation" not in dataset.keys():
|
302 |
+
dataset["validation"] = load_dataset(
|
303 |
+
data_args.dataset_name,
|
304 |
+
data_args.dataset_config_name,
|
305 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
306 |
+
cache_dir=model_args.cache_dir,
|
307 |
+
)
|
308 |
+
dataset["train"] = load_dataset(
|
309 |
+
data_args.dataset_name,
|
310 |
+
data_args.dataset_config_name,
|
311 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
312 |
+
cache_dir=model_args.cache_dir,
|
313 |
+
)
|
314 |
+
else:
|
315 |
+
data_files = {}
|
316 |
+
if data_args.train_file is not None:
|
317 |
+
data_files["train"] = data_args.train_file
|
318 |
+
if data_args.validation_file is not None:
|
319 |
+
data_files["validation"] = data_args.validation_file
|
320 |
+
extension = data_args.train_file.split(".")[-1]
|
321 |
+
if extension == "txt":
|
322 |
+
extension = "text"
|
323 |
+
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
324 |
+
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
325 |
+
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
326 |
+
|
327 |
+
# Load pretrained model and tokenizer
|
328 |
+
|
329 |
+
# Distributed training:
|
330 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
331 |
+
# download model & vocab.
|
332 |
+
if model_args.config_name:
|
333 |
+
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
334 |
+
elif model_args.model_name_or_path:
|
335 |
+
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
336 |
+
else:
|
337 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
338 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
339 |
+
|
340 |
+
if model_args.tokenizer_name:
|
341 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
342 |
+
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
343 |
+
)
|
344 |
+
elif model_args.model_name_or_path:
|
345 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
346 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
347 |
+
)
|
348 |
+
else:
|
349 |
+
raise ValueError(
|
350 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
351 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
352 |
+
)
|
353 |
+
|
354 |
+
if model_args.model_name_or_path:
|
355 |
+
model = FlaxAutoModelForCausalLM.from_pretrained(
|
356 |
+
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
357 |
+
)
|
358 |
+
else:
|
359 |
+
model = FlaxAutoModelForCausalLM.from_config(
|
360 |
+
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
361 |
+
)
|
362 |
+
|
363 |
+
# Preprocessing the datasets.
|
364 |
+
# First we tokenize all the texts.
|
365 |
+
if training_args.do_train:
|
366 |
+
column_names = dataset["train"].column_names
|
367 |
+
else:
|
368 |
+
column_names = dataset["validation"].column_names
|
369 |
+
text_column_name = "text" if "text" in column_names else column_names[0]
|
370 |
+
|
371 |
+
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
|
372 |
+
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
|
373 |
+
|
374 |
+
def tokenize_function(examples):
|
375 |
+
with CaptureLogger(tok_logger) as cl:
|
376 |
+
output = tokenizer(examples[text_column_name])
|
377 |
+
# clm input could be much much longer than block_size
|
378 |
+
if "Token indices sequence length is longer than the" in cl.out:
|
379 |
+
tok_logger.warning(
|
380 |
+
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
|
381 |
+
)
|
382 |
+
return output
|
383 |
+
|
384 |
+
tokenized_datasets = dataset.map(
|
385 |
+
tokenize_function,
|
386 |
+
batched=True,
|
387 |
+
num_proc=data_args.preprocessing_num_workers,
|
388 |
+
remove_columns=column_names,
|
389 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
390 |
+
)
|
391 |
+
|
392 |
+
if data_args.block_size is None:
|
393 |
+
block_size = tokenizer.model_max_length
|
394 |
+
if block_size > config.max_position_embeddings:
|
395 |
+
logger.warning(
|
396 |
+
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
397 |
+
"Picking 1024 instead. You can change that default value by passing --block_size xxx."
|
398 |
+
)
|
399 |
+
block_size = 1024
|
400 |
+
else:
|
401 |
+
if data_args.block_size > tokenizer.model_max_length:
|
402 |
+
logger.warning(
|
403 |
+
f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
|
404 |
+
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
|
405 |
+
)
|
406 |
+
block_size = min(data_args.block_size, tokenizer.model_max_length)
|
407 |
+
|
408 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
409 |
+
def group_texts(examples):
|
410 |
+
# Concatenate all texts.
|
411 |
+
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
412 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
413 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
414 |
+
# customize this part to your needs.
|
415 |
+
total_length = (total_length // block_size) * block_size
|
416 |
+
# Split by chunks of max_len.
|
417 |
+
result = {
|
418 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
419 |
+
for k, t in concatenated_examples.items()
|
420 |
+
}
|
421 |
+
result["labels"] = result["input_ids"].copy()
|
422 |
+
return result
|
423 |
+
|
424 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
425 |
+
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
|
426 |
+
# to preprocess.
|
427 |
+
#
|
428 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
429 |
+
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
430 |
+
|
431 |
+
lm_datasets = tokenized_datasets.map(
|
432 |
+
group_texts,
|
433 |
+
batched=True,
|
434 |
+
num_proc=data_args.preprocessing_num_workers,
|
435 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
436 |
+
)
|
437 |
+
|
438 |
+
if training_args.do_train:
|
439 |
+
if "train" not in tokenized_datasets:
|
440 |
+
raise ValueError("--do_train requires a train dataset")
|
441 |
+
train_dataset = lm_datasets["train"]
|
442 |
+
if data_args.max_train_samples is not None:
|
443 |
+
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
444 |
+
|
445 |
+
if training_args.do_eval:
|
446 |
+
if "validation" not in tokenized_datasets:
|
447 |
+
raise ValueError("--do_eval requires a validation dataset")
|
448 |
+
eval_dataset = lm_datasets["validation"]
|
449 |
+
if data_args.max_eval_samples is not None:
|
450 |
+
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
451 |
+
|
452 |
+
# Enable tensorboard only on the master node
|
453 |
+
if has_tensorboard and jax.process_index() == 0:
|
454 |
+
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
455 |
+
|
456 |
+
# Initialize our training
|
457 |
+
rng = jax.random.PRNGKey(training_args.seed)
|
458 |
+
rng, dropout_rng = jax.random.split(rng)
|
459 |
+
|
460 |
+
# Store some constant
|
461 |
+
num_epochs = int(training_args.num_train_epochs)
|
462 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
463 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
464 |
+
steps_per_epoch = len(train_dataset) // train_batch_size
|
465 |
+
total_train_steps = steps_per_epoch * num_epochs
|
466 |
+
|
467 |
+
# Create learning rate schedule
|
468 |
+
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
469 |
+
len(train_dataset),
|
470 |
+
train_batch_size,
|
471 |
+
training_args.num_train_epochs,
|
472 |
+
training_args.warmup_steps,
|
473 |
+
training_args.learning_rate,
|
474 |
+
)
|
475 |
+
|
476 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
477 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
478 |
+
# mask boolean with the same structure as the parameters.
|
479 |
+
# The mask is True for parameters that should be decayed.
|
480 |
+
# Note that this mask is specifically adapted for FlaxGPT2.
|
481 |
+
# For other models, one should correct the layer norm parameter naming
|
482 |
+
# accordingly.
|
483 |
+
def decay_mask_fn(params):
|
484 |
+
flat_params = traverse_util.flatten_dict(params)
|
485 |
+
flat_mask = {
|
486 |
+
path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
|
487 |
+
for path in flat_params
|
488 |
+
}
|
489 |
+
return traverse_util.unflatten_dict(flat_mask)
|
490 |
+
|
491 |
+
# create adam optimizer
|
492 |
+
adamw = optax.adamw(
|
493 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
494 |
+
b1=training_args.adam_beta1,
|
495 |
+
b2=training_args.adam_beta2,
|
496 |
+
eps=training_args.adam_epsilon,
|
497 |
+
weight_decay=training_args.weight_decay,
|
498 |
+
mask=decay_mask_fn,
|
499 |
+
)
|
500 |
+
|
501 |
+
# Setup train state
|
502 |
+
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
503 |
+
|
504 |
+
def loss_fn(logits, labels):
|
505 |
+
shift_logits = logits[..., :-1, :]
|
506 |
+
shift_labels = labels[..., 1:]
|
507 |
+
loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
|
508 |
+
return loss.mean()
|
509 |
+
|
510 |
+
# Define gradient update step fn
|
511 |
+
def train_step(state, batch):
|
512 |
+
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
513 |
+
|
514 |
+
def compute_loss(params):
|
515 |
+
labels = batch.pop("labels")
|
516 |
+
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
517 |
+
loss = loss_fn(logits, labels)
|
518 |
+
return loss
|
519 |
+
|
520 |
+
grad_fn = jax.value_and_grad(compute_loss)
|
521 |
+
loss, grad = grad_fn(state.params)
|
522 |
+
grad = jax.lax.pmean(grad, "batch")
|
523 |
+
|
524 |
+
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
525 |
+
|
526 |
+
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
527 |
+
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
528 |
+
|
529 |
+
return new_state, metrics
|
530 |
+
|
531 |
+
# Define eval fn
|
532 |
+
def eval_step(params, batch):
|
533 |
+
labels = batch.pop("labels")
|
534 |
+
logits = model(**batch, params=params, train=False)[0]
|
535 |
+
loss = loss_fn(logits, labels)
|
536 |
+
|
537 |
+
# summarize metrics
|
538 |
+
metrics = {"loss": loss}
|
539 |
+
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
540 |
+
return metrics
|
541 |
+
|
542 |
+
# Create parallel version of the train and eval step
|
543 |
+
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
544 |
+
p_eval_step = jax.pmap(eval_step, "batch")
|
545 |
+
|
546 |
+
# Replicate the train state on each device
|
547 |
+
state = state.replicate()
|
548 |
+
|
549 |
+
logger.info("***** Running training *****")
|
550 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
551 |
+
logger.info(f" Num Epochs = {num_epochs}")
|
552 |
+
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
553 |
+
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
554 |
+
logger.info(f" Total optimization steps = {total_train_steps}")
|
555 |
+
|
556 |
+
train_time = 0
|
557 |
+
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
558 |
+
for epoch in epochs:
|
559 |
+
# ======================== Training ================================
|
560 |
+
train_start = time.time()
|
561 |
+
|
562 |
+
# Create sampling rng
|
563 |
+
rng, input_rng = jax.random.split(rng)
|
564 |
+
train_metrics = []
|
565 |
+
|
566 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
567 |
+
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
568 |
+
steps_per_epoch = len(train_dataset) // train_batch_size
|
569 |
+
# train
|
570 |
+
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
571 |
+
batch = next(train_loader)
|
572 |
+
state, train_metric = p_train_step(state, batch)
|
573 |
+
train_metrics.append(train_metric)
|
574 |
+
|
575 |
+
train_time += time.time() - train_start
|
576 |
+
|
577 |
+
train_metric = unreplicate(train_metric)
|
578 |
+
|
579 |
+
epochs.write(
|
580 |
+
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
581 |
+
)
|
582 |
+
|
583 |
+
# ======================== Evaluating ==============================
|
584 |
+
eval_metrics = []
|
585 |
+
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
586 |
+
eval_steps = len(eval_dataset) // eval_batch_size
|
587 |
+
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
588 |
+
# Model forward
|
589 |
+
batch = next(eval_loader)
|
590 |
+
metrics = p_eval_step(state.params, batch)
|
591 |
+
eval_metrics.append(metrics)
|
592 |
+
|
593 |
+
# normalize eval metrics
|
594 |
+
eval_metrics = get_metrics(eval_metrics)
|
595 |
+
|
596 |
+
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
597 |
+
|
598 |
+
try:
|
599 |
+
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
|
600 |
+
except OverflowError:
|
601 |
+
eval_metrics["perplexity"] = float("inf")
|
602 |
+
|
603 |
+
# Print metrics and update progress bar
|
604 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
|
605 |
+
epochs.write(desc)
|
606 |
+
epochs.desc = desc
|
607 |
+
|
608 |
+
# Save metrics
|
609 |
+
if has_tensorboard and jax.process_index() == 0:
|
610 |
+
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
611 |
+
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
612 |
+
|
613 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
614 |
+
if jax.process_index() == 0:
|
615 |
+
params = jax.device_get(unreplicate(state.params))
|
616 |
+
model.save_pretrained(
|
617 |
+
training_args.output_dir,
|
618 |
+
params=params,
|
619 |
+
push_to_hub=training_args.push_to_hub,
|
620 |
+
commit_message=f"Saving weights and logs of epoch {epoch+1}",
|
621 |
+
)
|
622 |
+
|
623 |
+
|
624 |
+
if __name__ == "__main__":
|
625 |
+
main()
|