m3hrdadfi committed on
Commit 7cfca48
1 Parent(s): 8812e32

Add training script with checkpoint and preprocessing + merge scripts

src/data_utils.py ADDED
@@ -0,0 +1,31 @@
+from hazm import word_tokenize
+from hazm import sent_tokenize
+import re
+import six
+
+from normalizer import normalize
+
+persian_regex = "0-9۰۱۲۳۴۵۶۷۸۹ءآئابتثجحخدذرزسشصضطظعغفقلمنهوپچژکگیە\u200c"
+
+
+def filter_by_lang_regex(text, ratio=0.7, regex="0-9۰۱۲۳۴۵۶۷۸۹ءآئابتثجحخدذرزسشصضطظعغفقلمنهوپچژکگیە\u200c"):
+    candidate_text = re.sub(r"[^" + regex + "]+", " ", six.ensure_str(text)).replace(" ", "")
+    text = text.replace(" ", "")
+
+    return True if (len(candidate_text) / len(text)) > ratio else False
+
+
+def filter_by_num_tokens(text, gt=64):
+    return True if len(word_tokenize(text)) > gt else False
+
+
+def filter_by_num_sents(text, gt=2):
+    return True if len(sent_tokenize(text)) > gt else False
+
+
+def normalizer(text, do_lowercase=False):
+    text = normalize(text)
+    if do_lowercase:
+        text = text.lower()
+
+    return text
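For context, a minimal usage sketch of these helpers on a Hugging Face datasets object (the sample rows are invented; hazm, six, and the repo's normalizer module are assumed to be installed):

    from datasets import Dataset
    from data_utils import filter_by_lang_regex, filter_by_num_tokens, filter_by_num_sents, normalizer

    # toy rows; the real scripts run this on the OSCAR Persian split
    ds = Dataset.from_dict({"text": ["... یک متن بلند فارسی ...", "a mostly-English row that should be dropped"]})

    # keep rows that are mostly Persian characters, long enough, and contain several sentences
    ds = ds.filter(lambda ex: filter_by_lang_regex(ex["text"], ratio=0.75))
    ds = ds.filter(lambda ex: filter_by_num_tokens(ex["text"], gt=64))
    ds = ds.filter(lambda ex: filter_by_num_sents(ex["text"], gt=2))

    # normalizer() takes a string, so wrap it so map() receives a column dict back
    ds = ds.map(lambda ex: {"text": normalizer(ex["text"])})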
src/prep_dataset.py CHANGED
@@ -5,7 +5,7 @@ from normalizer import normalize
 
 class Prep_dataset:
 
-    def __init__(self, subsample=False,*args, **kwargs):
+    def __init__(self, subsample=False, *args, **kwargs):
         raw_dataset = load_dataset("oscar", f"unshuffled_deduplicated_fa")
         if subsample:
             sample_dataset = raw_dataset.copy()
@@ -16,15 +16,14 @@ class Prep_dataset:
             self.raw_dataset = final
         else:
             self.raw_dataset = raw_dataset
-
 
     def _normalize(self, example):
         example["text"] = normalize(example["text"])
         return example
 
     def preprare_dataset(self):
-        big_dataset = self.raw_dataset.filter(lambda x: len(x["text"])>500)
-        richSent_dataset = big_dataset.filter(lambda x: len(sent_tokenize(x["text"]))>2)
+        big_dataset = self.raw_dataset.filter(lambda x: len(x["text"]) > 500)
+        richSent_dataset = big_dataset.filter(lambda x: len(sent_tokenize(x["text"])) > 2)
         normalized_dataset = richSent_dataset.map(self._normalize)
 
-        return normalized_dataset
+        return normalized_dataset
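A sketch of how this class might be driven (names exactly as committed, including the preprare_dataset spelling; downloading the OSCAR Persian split is large, so this is illustrative only):

    from prep_dataset import Prep_dataset

    prep = Prep_dataset(subsample=False)   # loads oscar / unshuffled_deduplicated_fa
    clean = prep.preprare_dataset()        # length filter (> 500 chars), sentence filter (> 2), then normalize
    print(clean)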
src/run_clm_flax.py CHANGED
@@ -35,11 +35,13 @@ from datasets import Dataset, load_dataset
35
  from tqdm import tqdm
36
 
37
  import jax
 
38
  import jax.numpy as jnp
39
  import optax
40
  import transformers
41
  from flax import jax_utils, traverse_util
42
  from flax.jax_utils import unreplicate
 
43
  from flax.training import train_state
44
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
45
  from transformers import (
@@ -54,6 +56,12 @@ from transformers import (
54
  )
55
  from transformers.testing_utils import CaptureLogger
56
 
 
 
 
 
 
 
57
 
58
  logger = logging.getLogger(__name__)
59
 
@@ -72,7 +80,6 @@ else:
72
  "Please run pip install tensorboard to enable."
73
  )
74
 
75
-
76
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
77
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
78
 
@@ -87,7 +94,7 @@ class ModelArguments:
87
  default=None,
88
  metadata={
89
  "help": "The model checkpoint for weights initialization."
90
- "Don't set if you want to train a model from scratch."
91
  },
92
  )
93
  model_type: Optional[str] = field(
@@ -136,14 +143,14 @@ class DataTrainingArguments:
136
  default=None,
137
  metadata={
138
  "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
139
- "value if set."
140
  },
141
  )
142
  max_eval_samples: Optional[int] = field(
143
  default=None,
144
  metadata={
145
  "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
146
- "value if set."
147
  },
148
  )
149
  overwrite_cache: bool = field(
@@ -159,8 +166,8 @@ class DataTrainingArguments:
159
  default=None,
160
  metadata={
161
  "help": "Optional input sequence length after tokenization. "
162
- "The training dataset will be truncated in block of this size for training. "
163
- "Default to the model max input length for single sentence inputs (take into account special tokens)."
164
  },
165
  )
166
  overwrite_cache: bool = field(
@@ -214,7 +221,19 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
214
  yield batch
215
 
216
 
217
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 
 
 
 
 
 
 
 
 
 
218
  summary_writer.scalar("train_time", train_time, step)
219
 
220
  train_metrics = get_metrics(train_metrics)
@@ -223,12 +242,14 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
223
  for i, val in enumerate(vals):
224
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
225
 
 
 
226
  for metric_name, value in eval_metrics.items():
227
  summary_writer.scalar(f"eval_{metric_name}", value, step)
228
 
229
 
230
  def create_learning_rate_fn(
231
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
232
  ) -> Callable[[int], jnp.array]:
233
  """Returns a linear warmup, linear_decay learning rate function."""
234
  steps_per_epoch = train_ds_size // train_batch_size
@@ -255,10 +276,10 @@ def main():
255
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
256
 
257
  if (
258
- os.path.exists(training_args.output_dir)
259
- and os.listdir(training_args.output_dir)
260
- and training_args.do_train
261
- and not training_args.overwrite_output_dir
262
  ):
263
  raise ValueError(
264
  f"Output directory ({training_args.output_dir}) already exists and is not empty."
@@ -283,6 +304,9 @@ def main():
283
  # Set the verbosity to info of the Transformers logger (on main process only):
284
  logger.info(f"Training/evaluation parameters {training_args}")
285
 
 
 
 
286
  # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
287
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
288
  # (the dataset will be downloaded automatically from the datasets Hub).
@@ -294,18 +318,21 @@ def main():
294
  # download the dataset.
295
  if data_args.dataset_name is not None:
296
  # Downloading and loading a dataset from the hub.
297
- dataset = load_dataset(
298
- data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
 
 
 
299
  )
300
 
301
- if "validation" not in dataset.keys():
302
- dataset["validation"] = load_dataset(
303
  data_args.dataset_name,
304
  data_args.dataset_config_name,
305
  split=f"train[:{data_args.validation_split_percentage}%]",
306
  cache_dir=model_args.cache_dir,
307
  )
308
- dataset["train"] = load_dataset(
309
  data_args.dataset_name,
310
  data_args.dataset_config_name,
311
  split=f"train[{data_args.validation_split_percentage}%:]",
@@ -320,9 +347,22 @@ def main():
320
  extension = data_args.train_file.split(".")[-1]
321
  if extension == "txt":
322
  extension = "text"
323
- dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
 
 
 
 
 
 
 
324
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
325
  # https://huggingface.co/docs/datasets/loading_datasets.html.
 
 
 
 
 
 
326
 
327
  # Load pretrained model and tokenizer
328
 
@@ -339,11 +379,15 @@ def main():
339
 
340
  if model_args.tokenizer_name:
341
  tokenizer = AutoTokenizer.from_pretrained(
342
- model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
343
  )
344
  elif model_args.model_name_or_path:
345
  tokenizer = AutoTokenizer.from_pretrained(
346
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
347
  )
348
  else:
349
  raise ValueError(
@@ -353,11 +397,16 @@ def main():
353
 
354
  if model_args.model_name_or_path:
355
  model = FlaxAutoModelForCausalLM.from_pretrained(
356
- model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
 
 
 
357
  )
358
  else:
359
  model = FlaxAutoModelForCausalLM.from_config(
360
- config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
 
 
361
  )
362
 
363
  # Preprocessing the datasets.
@@ -415,7 +464,7 @@ def main():
415
  total_length = (total_length // block_size) * block_size
416
  # Split by chunks of max_len.
417
  result = {
418
- k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
419
  for k, t in concatenated_examples.items()
420
  }
421
  result["labels"] = result["input_ids"].copy()
@@ -554,6 +603,7 @@ def main():
554
  logger.info(f" Total optimization steps = {total_train_steps}")
555
 
556
  train_time = 0
 
557
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
558
  for epoch in epochs:
559
  # ======================== Training ================================
@@ -561,24 +611,30 @@ def main():
561
 
562
  # Create sampling rng
563
  rng, input_rng = jax.random.split(rng)
564
- train_metrics = []
565
 
566
  # Generate an epoch by shuffling sampling indices from the train dataset
567
  train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
568
  steps_per_epoch = len(train_dataset) // train_batch_size
569
  # train
570
- for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
571
  batch = next(train_loader)
572
  state, train_metric = p_train_step(state, batch)
573
  train_metrics.append(train_metric)
574
 
575
- train_time += time.time() - train_start
576
 
577
- train_metric = unreplicate(train_metric)
 
 
 
 
 
578
 
579
- epochs.write(
580
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
581
- )
 
 
582
 
583
  # ======================== Evaluating ==============================
584
  eval_metrics = []
@@ -608,7 +664,7 @@ def main():
608
  # Save metrics
609
  if has_tensorboard and jax.process_index() == 0:
610
  cur_step = epoch * (len(train_dataset) // train_batch_size)
611
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
612
 
613
  # save checkpoint after each epoch and push checkpoint to the hub
614
  if jax.process_index() == 0:
@@ -617,7 +673,7 @@ def main():
617
  training_args.output_dir,
618
  params=params,
619
  push_to_hub=training_args.push_to_hub,
620
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
621
  )
622
 
623
 
 
35
  from tqdm import tqdm
36
 
37
  import jax
38
+ from jax import lax
39
  import jax.numpy as jnp
40
  import optax
41
  import transformers
42
  from flax import jax_utils, traverse_util
43
  from flax.jax_utils import unreplicate
44
+ from flax.training import checkpoints
45
  from flax.training import train_state
46
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
47
  from transformers import (
 
56
  )
57
  from transformers.testing_utils import CaptureLogger
58
 
59
+ from data_utils import (
60
+ filter_by_lang_regex,
61
+ filter_by_num_tokens,
62
+ filter_by_num_sents,
63
+ normalizer
64
+ )
65
 
66
  logger = logging.getLogger(__name__)
67
 
 
80
  "Please run pip install tensorboard to enable."
81
  )
82
 
 
83
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
84
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
85
 
 
94
  default=None,
95
  metadata={
96
  "help": "The model checkpoint for weights initialization."
97
+ "Don't set if you want to train a model from scratch."
98
  },
99
  )
100
  model_type: Optional[str] = field(
 
143
  default=None,
144
  metadata={
145
  "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
146
+ "value if set."
147
  },
148
  )
149
  max_eval_samples: Optional[int] = field(
150
  default=None,
151
  metadata={
152
  "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
153
+ "value if set."
154
  },
155
  )
156
  overwrite_cache: bool = field(
 
166
  default=None,
167
  metadata={
168
  "help": "Optional input sequence length after tokenization. "
169
+ "The training dataset will be truncated in block of this size for training. "
170
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
171
  },
172
  )
173
  overwrite_cache: bool = field(
 
221
  yield batch
222
 
223
 
224
+ # def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
225
+ # summary_writer.scalar("train_time", train_time, step)
226
+ #
227
+ # train_metrics = get_metrics(train_metrics)
228
+ # for key, vals in train_metrics.items():
229
+ # tag = f"train_{key}"
230
+ # for i, val in enumerate(vals):
231
+ # summary_writer.scalar(tag, val, step - len(vals) + i + 1)
232
+ #
233
+ # for metric_name, value in eval_metrics.items():
234
+ # summary_writer.scalar(f"eval_{metric_name}", value, step)
235
+
236
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
237
  summary_writer.scalar("train_time", train_time, step)
238
 
239
  train_metrics = get_metrics(train_metrics)
 
242
  for i, val in enumerate(vals):
243
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
244
 
245
+
246
+ def write_eval_metric(summary_writer, eval_metrics, step):
247
  for metric_name, value in eval_metrics.items():
248
  summary_writer.scalar(f"eval_{metric_name}", value, step)
249
 
250
 
251
  def create_learning_rate_fn(
252
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
253
  ) -> Callable[[int], jnp.array]:
254
  """Returns a linear warmup, linear_decay learning rate function."""
255
  steps_per_epoch = train_ds_size // train_batch_size
 
276
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
277
 
278
  if (
279
+ os.path.exists(training_args.output_dir)
280
+ and os.listdir(training_args.output_dir)
281
+ and training_args.do_train
282
+ and not training_args.overwrite_output_dir
283
  ):
284
  raise ValueError(
285
  f"Output directory ({training_args.output_dir}) already exists and is not empty."
 
304
  # Set the verbosity to info of the Transformers logger (on main process only):
305
  logger.info(f"Training/evaluation parameters {training_args}")
306
 
307
+ checkpoints_dir = os.path.join(training_args.output_dir, "checkpoints")
308
+ os.makedirs(checkpoints_dir, exist_ok=True)
309
+
310
  # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
311
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
312
  # (the dataset will be downloaded automatically from the datasets Hub).
 
318
  # download the dataset.
319
  if data_args.dataset_name is not None:
320
  # Downloading and loading a dataset from the hub.
321
+ raw_dataset = load_dataset(
322
+ data_args.dataset_name,
323
+ data_args.dataset_config_name,
324
+ cache_dir=model_args.cache_dir,
325
+ keep_in_memory=False
326
  )
327
 
328
+ if "validation" not in raw_dataset.keys():
329
+ raw_dataset["validation"] = load_dataset(
330
  data_args.dataset_name,
331
  data_args.dataset_config_name,
332
  split=f"train[:{data_args.validation_split_percentage}%]",
333
  cache_dir=model_args.cache_dir,
334
  )
335
+ raw_dataset["train"] = load_dataset(
336
  data_args.dataset_name,
337
  data_args.dataset_config_name,
338
  split=f"train[{data_args.validation_split_percentage}%:]",
 
347
  extension = data_args.train_file.split(".")[-1]
348
  if extension == "txt":
349
  extension = "text"
350
+
351
+ raw_dataset = load_dataset(
352
+ extension,
353
+ data_files=data_files,
354
+ delimiter="\t",
355
+ cache_dir=model_args.cache_dir
356
+ )
357
+
358
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
359
  # https://huggingface.co/docs/datasets/loading_datasets.html.
360
+ logger.info("Preprocessing the dataset")
361
+ dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
362
+ dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=128))
363
+ dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
364
+ dataset = dataset.map(normalizer)
365
+ logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
366
 
367
  # Load pretrained model and tokenizer
368
 
 
379
 
380
  if model_args.tokenizer_name:
381
  tokenizer = AutoTokenizer.from_pretrained(
382
+ model_args.tokenizer_name,
383
+ cache_dir=model_args.cache_dir,
384
+ use_fast=model_args.use_fast_tokenizer
385
  )
386
  elif model_args.model_name_or_path:
387
  tokenizer = AutoTokenizer.from_pretrained(
388
+ model_args.model_name_or_path,
389
+ cache_dir=model_args.cache_dir,
390
+ use_fast=model_args.use_fast_tokenizer
391
  )
392
  else:
393
  raise ValueError(
 
397
 
398
  if model_args.model_name_or_path:
399
  model = FlaxAutoModelForCausalLM.from_pretrained(
400
+ model_args.model_name_or_path,
401
+ config=config,
402
+ seed=training_args.seed,
403
+ dtype=getattr(jnp, model_args.dtype)
404
  )
405
  else:
406
  model = FlaxAutoModelForCausalLM.from_config(
407
+ config,
408
+ seed=training_args.seed,
409
+ dtype=getattr(jnp, model_args.dtype)
410
  )
411
 
412
  # Preprocessing the datasets.
 
464
  total_length = (total_length // block_size) * block_size
465
  # Split by chunks of max_len.
466
  result = {
467
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
468
  for k, t in concatenated_examples.items()
469
  }
470
  result["labels"] = result["input_ids"].copy()
 
603
  logger.info(f" Total optimization steps = {total_train_steps}")
604
 
605
  train_time = 0
606
+ train_metrics = []
607
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
608
  for epoch in epochs:
609
  # ======================== Training ================================
 
611
 
612
  # Create sampling rng
613
  rng, input_rng = jax.random.split(rng)
 
614
 
615
  # Generate an epoch by shuffling sampling indices from the train dataset
616
  train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
617
  steps_per_epoch = len(train_dataset) // train_batch_size
618
  # train
619
+ for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
620
  batch = next(train_loader)
621
  state, train_metric = p_train_step(state, batch)
622
  train_metrics.append(train_metric)
623
 
624
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
625
 
626
+ if cur_step % training_args.logging_steps and cur_step > 0:
627
+ # Save metrics
628
+ train_metric = unreplicate(train_metric)
629
+ train_time += time.time() - train_start
630
+ if has_tensorboard and jax.process_index() == 0:
631
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
632
 
633
+ epochs.write(
634
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
635
+ )
636
+
637
+ train_metrics = []
638
 
639
  # ======================== Evaluating ==============================
640
  eval_metrics = []
 
664
  # Save metrics
665
  if has_tensorboard and jax.process_index() == 0:
666
  cur_step = epoch * (len(train_dataset) // train_batch_size)
667
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
668
 
669
  # save checkpoint after each epoch and push checkpoint to the hub
670
  if jax.process_index() == 0:
 
673
  training_args.output_dir,
674
  params=params,
675
  push_to_hub=training_args.push_to_hub,
676
+ commit_message=f"Saving weights and logs of epoch {epoch + 1}",
677
  )
678
 
679
 
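The reworked loop above logs per step rather than per epoch, keyed on training_args.logging_steps. Assuming the intent is "log every logging_steps steps", the usual guard carries an explicit == 0; a tiny self-contained sketch of that pattern (logging_steps value is a stand-in):

    # self-contained sketch of the "log every N steps" guard the new loop appears to aim for
    logging_steps = 100          # stand-in for training_args.logging_steps
    for cur_step in range(1, 501):
        if cur_step % logging_steps == 0 and cur_step > 0:   # note the explicit == 0
            print(f"Step... ({cur_step}) -> flush train_metrics, write to TensorBoard")

As committed, the guard reads `if cur_step % training_args.logging_steps and cur_step > 0:`, which is truthy on steps that are not multiples of logging_steps; the save_steps guard added in run_clm_flax_with_ckpts.py follows the same form.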
src/run_clm_flax_with_ckpts.py ADDED
@@ -0,0 +1,700 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=causal-lm
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Callable, Optional
32
+
33
+ import datasets
34
+ from datasets import Dataset, load_dataset
35
+ from tqdm import tqdm
36
+
37
+ import jax
38
+ from jax import lax
39
+ import jax.numpy as jnp
40
+ import optax
41
+ import transformers
42
+ from flax import jax_utils, traverse_util
43
+ from flax.jax_utils import unreplicate
44
+ from flax.training import checkpoints
45
+ from flax.training import train_state
46
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
47
+ from transformers import (
48
+ CONFIG_MAPPING,
49
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
50
+ AutoConfig,
51
+ AutoTokenizer,
52
+ FlaxAutoModelForCausalLM,
53
+ HfArgumentParser,
54
+ TrainingArguments,
55
+ is_tensorboard_available,
56
+ )
57
+ from transformers.testing_utils import CaptureLogger
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+ # Cache the result
62
+ has_tensorboard = is_tensorboard_available()
63
+ if has_tensorboard:
64
+ try:
65
+ from flax.metrics.tensorboard import SummaryWriter
66
+ except ImportError as ie:
67
+ has_tensorboard = False
68
+ print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
69
+
70
+ else:
71
+ print(
72
+ "Unable to display metrics through TensorBoard because the package is not installed: "
73
+ "Please run pip install tensorboard to enable."
74
+ )
75
+
76
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
77
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
78
+
79
+
80
+ @dataclass
81
+ class ModelArguments:
82
+ """
83
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
84
+ """
85
+
86
+ model_name_or_path: Optional[str] = field(
87
+ default=None,
88
+ metadata={
89
+ "help": "The model checkpoint for weights initialization."
90
+ "Don't set if you want to train a model from scratch."
91
+ },
92
+ )
93
+ model_type: Optional[str] = field(
94
+ default=None,
95
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
96
+ )
97
+ config_name: Optional[str] = field(
98
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
99
+ )
100
+ tokenizer_name: Optional[str] = field(
101
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
102
+ )
103
+ cache_dir: Optional[str] = field(
104
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
105
+ )
106
+ use_fast_tokenizer: bool = field(
107
+ default=True,
108
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
109
+ )
110
+ dtype: Optional[str] = field(
111
+ default="float32",
112
+ metadata={
113
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
114
+ },
115
+ )
116
+
117
+
118
+ @dataclass
119
+ class DataTrainingArguments:
120
+ """
121
+ Arguments pertaining to what data we are going to input our model for training and eval.
122
+ """
123
+
124
+ dataset_name: Optional[str] = field(
125
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
126
+ )
127
+ dataset_config_name: Optional[str] = field(
128
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
129
+ )
130
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
131
+ validation_file: Optional[str] = field(
132
+ default=None,
133
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
134
+ )
135
+ max_train_samples: Optional[int] = field(
136
+ default=None,
137
+ metadata={
138
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
139
+ "value if set."
140
+ },
141
+ )
142
+ max_eval_samples: Optional[int] = field(
143
+ default=None,
144
+ metadata={
145
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
146
+ "value if set."
147
+ },
148
+ )
149
+ overwrite_cache: bool = field(
150
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
151
+ )
152
+ validation_split_percentage: Optional[int] = field(
153
+ default=5,
154
+ metadata={
155
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
156
+ },
157
+ )
158
+ block_size: Optional[int] = field(
159
+ default=None,
160
+ metadata={
161
+ "help": "Optional input sequence length after tokenization. "
162
+ "The training dataset will be truncated in block of this size for training. "
163
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
164
+ },
165
+ )
166
+ overwrite_cache: bool = field(
167
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
168
+ )
169
+ preprocessing_num_workers: Optional[int] = field(
170
+ default=None,
171
+ metadata={"help": "The number of processes to use for the preprocessing."},
172
+ )
173
+
174
+ def __post_init__(self):
175
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
176
+ raise ValueError("Need either a dataset name or a training/validation file.")
177
+ else:
178
+ if self.train_file is not None:
179
+ extension = self.train_file.split(".")[-1]
180
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
181
+ if self.validation_file is not None:
182
+ extension = self.validation_file.split(".")[-1]
183
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
184
+
185
+
186
+ class TrainState(train_state.TrainState):
187
+ dropout_rng: jnp.ndarray
188
+
189
+ def replicate(self):
190
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
191
+
192
+
193
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
194
+ """
195
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
196
+ Shuffle batches if `shuffle` is `True`.
197
+ """
198
+ steps_per_epoch = len(dataset) // batch_size
199
+
200
+ if shuffle:
201
+ batch_idx = jax.random.permutation(rng, len(dataset))
202
+ else:
203
+ batch_idx = jnp.arange(len(dataset))
204
+
205
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
206
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
207
+
208
+ for idx in batch_idx:
209
+ batch = dataset[idx]
210
+ batch = {k: jnp.array(v) for k, v in batch.items()}
211
+
212
+ batch = shard(batch)
213
+
214
+ yield batch
215
+
216
+
217
+ # def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
218
+ # summary_writer.scalar("train_time", train_time, step)
219
+ #
220
+ # train_metrics = get_metrics(train_metrics)
221
+ # for key, vals in train_metrics.items():
222
+ # tag = f"train_{key}"
223
+ # for i, val in enumerate(vals):
224
+ # summary_writer.scalar(tag, val, step - len(vals) + i + 1)
225
+ #
226
+ # for metric_name, value in eval_metrics.items():
227
+ # summary_writer.scalar(f"eval_{metric_name}", value, step)
228
+
229
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
230
+ summary_writer.scalar("train_time", train_time, step)
231
+
232
+ train_metrics = get_metrics(train_metrics)
233
+ for key, vals in train_metrics.items():
234
+ tag = f"train_{key}"
235
+ for i, val in enumerate(vals):
236
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
237
+
238
+
239
+ def write_eval_metric(summary_writer, eval_metrics, step):
240
+ for metric_name, value in eval_metrics.items():
241
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
242
+
243
+
244
+ def create_learning_rate_fn(
245
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
246
+ ) -> Callable[[int], jnp.array]:
247
+ """Returns a linear warmup, linear_decay learning rate function."""
248
+ steps_per_epoch = train_ds_size // train_batch_size
249
+ num_train_steps = steps_per_epoch * num_train_epochs
250
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
251
+ decay_fn = optax.linear_schedule(
252
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
253
+ )
254
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
255
+ return schedule_fn
256
+
257
+
258
+ def restore_checkpoint(state, workdir):
259
+ return checkpoints.restore_checkpoint(workdir, state)
260
+
261
+
262
+ def save_checkpoint(state, workdir):
263
+ if jax.process_index() == 0:
264
+ # get train state from the first replica
265
+ state = jax.device_get(jax.tree_map(lambda x: x[0], state))
266
+ step = int(state.step)
267
+ checkpoints.save_checkpoint(workdir, state, step, keep=3)
268
+
269
+
270
+ def main():
271
+ # See all possible arguments in src/transformers/training_args.py
272
+ # or by passing the --help flag to this script.
273
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
274
+
275
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
276
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
277
+ # If we pass only one argument to the script and it's the path to a json file,
278
+ # let's parse it to get our arguments.
279
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
280
+ else:
281
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
282
+
283
+ if (
284
+ os.path.exists(training_args.output_dir)
285
+ and os.listdir(training_args.output_dir)
286
+ and training_args.do_train
287
+ and not training_args.overwrite_output_dir
288
+ ):
289
+ raise ValueError(
290
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
291
+ "Use --overwrite_output_dir to overcome."
292
+ )
293
+
294
+ # Make one log on every process with the configuration for debugging.
295
+ logging.basicConfig(
296
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
297
+ datefmt="%m/%d/%Y %H:%M:%S",
298
+ level=logging.INFO,
299
+ )
300
+ # Setup logging, we only want one process per machine to log things on the screen.
301
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
302
+ if jax.process_index() == 0:
303
+ datasets.utils.logging.set_verbosity_warning()
304
+ transformers.utils.logging.set_verbosity_info()
305
+ else:
306
+ datasets.utils.logging.set_verbosity_error()
307
+ transformers.utils.logging.set_verbosity_error()
308
+
309
+ # Set the verbosity to info of the Transformers logger (on main process only):
310
+ logger.info(f"Training/evaluation parameters {training_args}")
311
+
312
+ checkpoints_dir = os.path.join(training_args.output_dir, "checkpoints")
313
+ os.makedirs(checkpoints_dir, exist_ok=True)
314
+
315
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
316
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
317
+ # (the dataset will be downloaded automatically from the datasets Hub).
318
+ #
319
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
320
+ # 'text' is found. You can easily tweak this behavior (see below).
321
+ #
322
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
323
+ # download the dataset.
324
+ if data_args.dataset_name is not None:
325
+ # Downloading and loading a dataset from the hub.
326
+ dataset = load_dataset(
327
+ data_args.dataset_name,
328
+ data_args.dataset_config_name,
329
+ cache_dir=model_args.cache_dir,
330
+ keep_in_memory=False
331
+ )
332
+
333
+ if "validation" not in dataset.keys():
334
+ dataset["validation"] = load_dataset(
335
+ data_args.dataset_name,
336
+ data_args.dataset_config_name,
337
+ split=f"train[:{data_args.validation_split_percentage}%]",
338
+ cache_dir=model_args.cache_dir,
339
+ )
340
+ dataset["train"] = load_dataset(
341
+ data_args.dataset_name,
342
+ data_args.dataset_config_name,
343
+ split=f"train[{data_args.validation_split_percentage}%:]",
344
+ cache_dir=model_args.cache_dir,
345
+ )
346
+ else:
347
+ data_files = {}
348
+ if data_args.train_file is not None:
349
+ data_files["train"] = data_args.train_file
350
+ if data_args.validation_file is not None:
351
+ data_files["validation"] = data_args.validation_file
352
+ extension = data_args.train_file.split(".")[-1]
353
+ if extension == "txt":
354
+ extension = "text"
355
+
356
+ dataset = load_dataset(
357
+ extension,
358
+ data_files=data_files,
359
+ delimiter="\t",
360
+ cache_dir=model_args.cache_dir
361
+ )
362
+
363
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
364
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
365
+
366
+ # Load pretrained model and tokenizer
367
+
368
+ # Distributed training:
369
+ # The .from_pretrained methods guarantee that only one local process can concurrently
370
+ # download model & vocab.
371
+ if model_args.config_name:
372
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
373
+ elif model_args.model_name_or_path:
374
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
375
+ else:
376
+ config = CONFIG_MAPPING[model_args.model_type]()
377
+ logger.warning("You are instantiating a new config instance from scratch.")
378
+
379
+ if model_args.tokenizer_name:
380
+ tokenizer = AutoTokenizer.from_pretrained(
381
+ model_args.tokenizer_name,
382
+ cache_dir=model_args.cache_dir,
383
+ use_fast=model_args.use_fast_tokenizer
384
+ )
385
+ elif model_args.model_name_or_path:
386
+ tokenizer = AutoTokenizer.from_pretrained(
387
+ model_args.model_name_or_path,
388
+ cache_dir=model_args.cache_dir,
389
+ use_fast=model_args.use_fast_tokenizer
390
+ )
391
+ else:
392
+ raise ValueError(
393
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
394
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
395
+ )
396
+
397
+ if model_args.model_name_or_path:
398
+ model = FlaxAutoModelForCausalLM.from_pretrained(
399
+ model_args.model_name_or_path,
400
+ config=config,
401
+ seed=training_args.seed,
402
+ dtype=getattr(jnp, model_args.dtype)
403
+ )
404
+ else:
405
+ model = FlaxAutoModelForCausalLM.from_config(
406
+ config,
407
+ seed=training_args.seed,
408
+ dtype=getattr(jnp, model_args.dtype)
409
+ )
410
+
411
+ # Preprocessing the datasets.
412
+ # First we tokenize all the texts.
413
+ if training_args.do_train:
414
+ column_names = dataset["train"].column_names
415
+ else:
416
+ column_names = dataset["validation"].column_names
417
+ text_column_name = "text" if "text" in column_names else column_names[0]
418
+
419
+ # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
420
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
421
+
422
+ def tokenize_function(examples):
423
+ with CaptureLogger(tok_logger) as cl:
424
+ output = tokenizer(examples[text_column_name])
425
+ # clm input could be much much longer than block_size
426
+ if "Token indices sequence length is longer than the" in cl.out:
427
+ tok_logger.warning(
428
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
429
+ )
430
+ return output
431
+
432
+ tokenized_datasets = dataset.map(
433
+ tokenize_function,
434
+ batched=True,
435
+ num_proc=data_args.preprocessing_num_workers,
436
+ remove_columns=column_names,
437
+ load_from_cache_file=not data_args.overwrite_cache,
438
+ )
439
+
440
+ if data_args.block_size is None:
441
+ block_size = tokenizer.model_max_length
442
+ if block_size > config.max_position_embeddings:
443
+ logger.warning(
444
+ f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
445
+ "Picking 1024 instead. You can change that default value by passing --block_size xxx."
446
+ )
447
+ block_size = 1024
448
+ else:
449
+ if data_args.block_size > tokenizer.model_max_length:
450
+ logger.warning(
451
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
452
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
453
+ )
454
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
455
+
456
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
457
+ def group_texts(examples):
458
+ # Concatenate all texts.
459
+ concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
460
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
461
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
462
+ # customize this part to your needs.
463
+ total_length = (total_length // block_size) * block_size
464
+ # Split by chunks of max_len.
465
+ result = {
466
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
467
+ for k, t in concatenated_examples.items()
468
+ }
469
+ result["labels"] = result["input_ids"].copy()
470
+ return result
471
+
472
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
473
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
474
+ # to preprocess.
475
+ #
476
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
477
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
478
+
479
+ lm_datasets = tokenized_datasets.map(
480
+ group_texts,
481
+ batched=True,
482
+ num_proc=data_args.preprocessing_num_workers,
483
+ load_from_cache_file=not data_args.overwrite_cache,
484
+ )
485
+
486
+ if training_args.do_train:
487
+ if "train" not in tokenized_datasets:
488
+ raise ValueError("--do_train requires a train dataset")
489
+ train_dataset = lm_datasets["train"]
490
+ if data_args.max_train_samples is not None:
491
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
492
+
493
+ if training_args.do_eval:
494
+ if "validation" not in tokenized_datasets:
495
+ raise ValueError("--do_eval requires a validation dataset")
496
+ eval_dataset = lm_datasets["validation"]
497
+ if data_args.max_eval_samples is not None:
498
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
499
+
500
+ # Enable tensorboard only on the master node
501
+ if has_tensorboard and jax.process_index() == 0:
502
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
503
+
504
+ # Initialize our training
505
+ rng = jax.random.PRNGKey(training_args.seed)
506
+ rng, dropout_rng = jax.random.split(rng)
507
+
508
+ # Store some constant
509
+ num_epochs = int(training_args.num_train_epochs)
510
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
511
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
512
+ steps_per_epoch = len(train_dataset) // train_batch_size
513
+
514
+ # total_train_steps = steps_per_epoch * num_epochs
515
+ if training_args.max_steps == -1:
516
+ total_train_steps = steps_per_epoch * num_epochs
517
+ else:
518
+ total_train_steps = training_args.max_steps
519
+
520
+ # Create learning rate schedule
521
+ linear_decay_lr_schedule_fn = create_learning_rate_fn(
522
+ len(train_dataset),
523
+ train_batch_size,
524
+ training_args.num_train_epochs,
525
+ training_args.warmup_steps,
526
+ training_args.learning_rate,
527
+ )
528
+
529
+ # We use Optax's "masking" functionality to not apply weight decay
530
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
531
+ # mask boolean with the same structure as the parameters.
532
+ # The mask is True for parameters that should be decayed.
533
+ # Note that this mask is specifically adapted for FlaxGPT2.
534
+ # For other models, one should correct the layer norm parameter naming
535
+ # accordingly.
536
+ def decay_mask_fn(params):
537
+ flat_params = traverse_util.flatten_dict(params)
538
+ flat_mask = {
539
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
540
+ for path in flat_params
541
+ }
542
+ return traverse_util.unflatten_dict(flat_mask)
543
+
544
+ # create adam optimizer
545
+ adamw = optax.adamw(
546
+ learning_rate=linear_decay_lr_schedule_fn,
547
+ b1=training_args.adam_beta1,
548
+ b2=training_args.adam_beta2,
549
+ eps=training_args.adam_epsilon,
550
+ weight_decay=training_args.weight_decay,
551
+ mask=decay_mask_fn,
552
+ )
553
+
554
+ # Setup train state
555
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
556
+
557
+ # Restore states
558
+ state = restore_checkpoint(state, checkpoints_dir)
559
+ step_offset = int(state.step) # step_offset > 0 if restarting from checkpoint
560
+ epoch_offset = int(num_epochs - ((total_train_steps - step_offset) / steps_per_epoch))
561
+
562
+ def loss_fn(logits, labels):
563
+ shift_logits = logits[..., :-1, :]
564
+ shift_labels = labels[..., 1:]
565
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
566
+ return loss.mean()
567
+
568
+ # Define gradient update step fn
569
+ def train_step(state, batch):
570
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
571
+
572
+ def compute_loss(params):
573
+ labels = batch.pop("labels")
574
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
575
+ loss = loss_fn(logits, labels)
576
+ return loss
577
+
578
+ grad_fn = jax.value_and_grad(compute_loss)
579
+ loss, grad = grad_fn(state.params)
580
+ grad = jax.lax.pmean(grad, "batch")
581
+
582
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
583
+
584
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
585
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
586
+
587
+ return new_state, metrics
588
+
589
+ # Define eval fn
590
+ def eval_step(params, batch):
591
+ labels = batch.pop("labels")
592
+ logits = model(**batch, params=params, train=False)[0]
593
+ loss = loss_fn(logits, labels)
594
+
595
+ # summarize metrics
596
+ metrics = {"loss": loss}
597
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
598
+ return metrics
599
+
600
+ # Create parallel version of the train and eval step
601
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
602
+ p_eval_step = jax.pmap(eval_step, "batch")
603
+
604
+ # Replicate the train state on each device
605
+ state = state.replicate()
606
+
607
+ logger.info("***** Running training *****")
608
+ logger.info(f" Num examples = {len(train_dataset)}")
609
+ logger.info(f" Num Epochs = {num_epochs}")
610
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
611
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
612
+ logger.info(f" Total optimization steps = {total_train_steps}")
613
+
614
+ if step_offset > 0:
615
+ logger.info(" Continuing training from checkpoint")
616
+ logger.info(f" Continuing training from epoch {epoch_offset}")
617
+ logger.info(f" Continuing training from global step {step_offset}")
618
+
619
+ train_time = 0
620
+ train_metrics = []
621
+ epochs = tqdm(range(epoch_offset, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
622
+ for epoch in epochs:
623
+ # ======================== Training ================================
624
+ train_start = time.time()
625
+
626
+ # Create sampling rng
627
+ rng, input_rng = jax.random.split(rng)
628
+
629
+ # Generate an epoch by shuffling sampling indices from the train dataset
630
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
631
+ steps_per_epoch = len(train_dataset) // train_batch_size
632
+ num_steps = abs(step_offset - (steps_per_epoch * (epoch + 1)))
633
+
634
+ # train
635
+ for step in tqdm(range(num_steps), desc="Training...", position=1, leave=False):
636
+ batch = next(train_loader)
637
+ state, train_metric = p_train_step(state, batch)
638
+ train_metrics.append(train_metric)
639
+
640
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
641
+
642
+ if cur_step % training_args.logging_steps and cur_step > 0:
643
+ # Save metrics
644
+ train_metric = unreplicate(train_metric)
645
+ train_time += time.time() - train_start
646
+ if has_tensorboard and jax.process_index() == 0:
647
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
648
+
649
+ epochs.write(
650
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
651
+ )
652
+
653
+ train_metrics = []
654
+
655
+ if cur_step % training_args.save_steps and cur_step > 0:
656
+ save_checkpoint(state, checkpoints_dir)
657
+
658
+ # ======================== Evaluating ==============================
659
+ eval_metrics = []
660
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
661
+ eval_steps = len(eval_dataset) // eval_batch_size
662
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
663
+ # Model forward
664
+ batch = next(eval_loader)
665
+ metrics = p_eval_step(state.params, batch)
666
+ eval_metrics.append(metrics)
667
+
668
+ # normalize eval metrics
669
+ eval_metrics = get_metrics(eval_metrics)
670
+
671
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
672
+
673
+ try:
674
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
675
+ except OverflowError:
676
+ eval_metrics["perplexity"] = float("inf")
677
+
678
+ # Print metrics and update progress bar
679
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
680
+ epochs.write(desc)
681
+ epochs.desc = desc
682
+
683
+ # Save metrics
684
+ if has_tensorboard and jax.process_index() == 0:
685
+ cur_step = epoch * (len(train_dataset) // train_batch_size)
686
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
687
+
688
+ # save checkpoint after each epoch and push checkpoint to the hub
689
+ if jax.process_index() == 0:
690
+ params = jax.device_get(unreplicate(state.params))
691
+ model.save_pretrained(
692
+ training_args.output_dir,
693
+ params=params,
694
+ push_to_hub=training_args.push_to_hub,
695
+ commit_message=f"Saving weights and logs of epoch {epoch + 1}",
696
+ )
697
+
698
+
699
+ if __name__ == "__main__":
700
+ main()
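Condensed, the checkpoint handling this new script adds is flax.training.checkpoints plus a step offset. A minimal standalone sketch of that flow, using a toy unreplicated TrainState (the script itself works on a pmapped state and unreplicates before saving):

    import tempfile

    import jax
    import jax.numpy as jnp
    import optax
    from flax.training import checkpoints, train_state

    # toy state; the script wraps model.__call__ and AdamW instead
    state = train_state.TrainState.create(
        apply_fn=lambda params, x: x @ params["w"],
        params={"w": jnp.ones((2, 2))},
        tx=optax.sgd(1e-3),
    )

    workdir = tempfile.mkdtemp()   # stand-in for the script's checkpoints_dir

    # save: keep only the 3 most recent checkpoints, as in save_checkpoint() above
    checkpoints.save_checkpoint(workdir, jax.device_get(state), int(state.step), keep=3)

    # restore: if workdir holds no checkpoint, the passed-in state is returned unchanged
    state = checkpoints.restore_checkpoint(workdir, state)
    step_offset = int(state.step)  # > 0 means training is resuming mid-run
    # the script then derives an epoch offset from step_offset to skip already-finished epochs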
src/train_tokenizer.py CHANGED
@@ -11,6 +11,13 @@ from transformers import (
     HfArgumentParser,
 )
 
+from data_utils import (
+    filter_by_lang_regex,
+    filter_by_num_tokens,
+    filter_by_num_sents,
+    normalizer
+)
+
 logger = logging.getLogger(__name__)
 
 
@@ -89,11 +96,11 @@ def main():
     logger.info(f"Training tokenizer")
 
     if tokenizer_args.dataset_name is not None:
-        dataset = load_dataset(
+        raw_dataset = load_dataset(
             tokenizer_args.dataset_name,
             tokenizer_args.dataset_config_name,
             cache_dir=tokenizer_args.cache_dir,
-            split="train"
+            split="train[:10%]"
         )
     else:
         data_files = {"train": tokenizer_args.train_file}
@@ -101,13 +108,20 @@ def main():
         if extension == "txt":
             extension = "text"
 
-        dataset = load_dataset(
+        raw_dataset = load_dataset(
             extension,
             data_files=data_files,
             delimiter="\t",
             cache_dir=tokenizer_args.cache_dir,
         )
 
+    logger.info("Preprocessing the dataset")
+    dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
+    dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
+    dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
+    dataset = dataset.map(normalizer)
+    logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")
+
     tokenizer = ByteLevelBPETokenizer()
 
     def batch_iterative(batch_size=1000):
@@ -122,7 +136,7 @@ def main():
         show_progress=tokenizer_args.show_progress,
     )
 
-    logger.info(f"Your tokenizer saved here {tokenizer_args.output_dir}/tokenizer")
+    logger.info(f"Your tokenizer saved here {tokenizer_args.output_dir}")
     os.makedirs(tokenizer_args.output_dir, exist_ok=True)
     tokenizer.save_model(tokenizer_args.output_dir)
    tokenizer.save(f"{tokenizer_args.output_dir}/tokenizer.json", pretty=True)
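Taken together with the filters above, the tokenizer path is now: load a slice of the corpus, clean it, then feed batches of text into a byte-level BPE trainer. A minimal standalone sketch of that last step (the sample texts, vocabulary size, and special tokens are illustrative, not the repo's settings):

    import os
    from tokenizers import ByteLevelBPETokenizer

    texts = ["یک جمله فارسی.", "جمله دوم برای آموزش توکنایزر."]   # stand-in for dataset["train"]["text"]

    def batch_iterative(batch_size=1000):
        # yield the corpus in chunks, as the script's helper does
        for i in range(0, len(texts), batch_size):
            yield texts[i: i + batch_size]

    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train_from_iterator(
        batch_iterative(),
        vocab_size=42000,                     # illustrative
        min_frequency=2,
        special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
        show_progress=True,
    )

    output_dir = "tokenizer_out"              # stand-in for tokenizer_args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    tokenizer.save_model(output_dir)          # writes vocab.json + merges.txt
    tokenizer.save(f"{output_dir}/tokenizer.json", pretty=True)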