Pedro Cuenca commited on
Commit
77657e6
·
unverified ·
2 Parent(s): a09ea25 a37cd75

Merge pull request #71 from borisdayma/fix-opt_state

Browse files

seq2seq: resume from checkpoint, streaming, text normalization, cache limiting.

dev/seq2seq/do_big_run.sh CHANGED
@@ -1,12 +1,16 @@
1
  python run_seq2seq_flax.py \
2
- --max_source_length 128 \
3
- --train_file /data/CC12M/encoded-small-train.tsv \
4
- --validation_file /data/CC12M/encoded-small-valid.tsv \
 
 
 
 
5
  --output_dir output \
6
  --per_device_train_batch_size 56 \
7
  --per_device_eval_batch_size 56 \
8
  --preprocessing_num_workers 80 \
9
- --warmup_steps 250 \
10
  --gradient_accumulation_steps 8 \
11
  --do_train \
12
  --do_eval \
 
1
  python run_seq2seq_flax.py \
2
+ --dataset_repo_or_path dalle-mini/encoded \
3
+ --train_file **/train/*/*.jsonl \
4
+ --validation_file **/valid/*/*.jsonl \
5
+ --len_train 42684248 \
6
+ --len_eval 34328 \
7
+ --streaming \
8
+ --normalize_text \
9
  --output_dir output \
10
  --per_device_train_batch_size 56 \
11
  --per_device_eval_batch_size 56 \
12
  --preprocessing_num_workers 80 \
13
+ --warmup_steps 500 \
14
  --gradient_accumulation_steps 8 \
15
  --do_train \
16
  --do_eval \
dev/seq2seq/do_small_run.sh CHANGED
@@ -1,7 +1,10 @@
1
  python run_seq2seq_flax.py \
2
- --max_source_length 128 \
3
- --train_file /data/CC12M/encoded-small-train.tsv \
4
- --validation_file /data/CC12M/encoded-small-valid.tsv \
 
 
 
5
  --output_dir output \
6
  --per_device_train_batch_size 56 \
7
  --per_device_eval_batch_size 56 \
@@ -12,5 +15,5 @@ python run_seq2seq_flax.py \
12
  --do_eval \
13
  --adafactor \
14
  --num_train_epochs 1 \
15
- --max_train_samples 20000 \
16
  --learning_rate 0.005
 
1
  python run_seq2seq_flax.py \
2
+ --dataset_repo_or_path dalle-mini/encoded \
3
+ --train_file **/train/*/*.jsonl \
4
+ --validation_file **/valid/*/*.jsonl \
5
+ --len_train 42684248 \
6
+ --len_eval 34328 \
7
+ --streaming \
8
  --output_dir output \
9
  --per_device_train_batch_size 56 \
10
  --per_device_eval_batch_size 56 \
 
15
  --do_eval \
16
  --adafactor \
17
  --num_train_epochs 1 \
18
+ --max_train_samples 10000 \
19
  --learning_rate 0.005
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -20,13 +20,8 @@ Script adapted from run_summarization_flax.py
20
  # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
21
 
22
  import os
23
- # set a common huggingface cache folder (used with datasets and transformers) and wandb cache folder (used with artifacts)
24
- os.environ['HF_HOME'] = '/data/huggingface/' # required before importing transformers & datasets
25
- os.environ['WANDB_CACHE_DIR'] = '/data/wandb/' # required before importing wandb
26
-
27
- import logging as pylogging # To avoid collision with transformers.utils.logging
28
  import sys
29
- import time
30
  from dataclasses import dataclass, field
31
  from functools import partial
32
  from pathlib import Path
@@ -34,7 +29,6 @@ from typing import Callable, Optional
34
  import json
35
 
36
  import datasets
37
- import nltk # Here to have a nice missing dependency error message early on
38
  import numpy as np
39
  from datasets import Dataset, load_dataset, load_metric
40
  from tqdm import tqdm
@@ -51,9 +45,7 @@ from flax.jax_utils import unreplicate
51
  from flax.training import train_state
52
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
53
  from transformers import (
54
- CONFIG_MAPPING,
55
  FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
56
- AutoConfig,
57
  AutoTokenizer,
58
  FlaxAutoModelForSeq2SeqLM,
59
  FlaxBartForConditionalGeneration,
@@ -65,17 +57,9 @@ from transformers.file_utils import is_offline_mode
65
 
66
  import wandb
67
 
68
- logger = pylogging.getLogger(__name__)
69
 
70
- try:
71
- nltk.data.find("tokenizers/punkt")
72
- except (LookupError, OSError):
73
- if is_offline_mode():
74
- raise LookupError(
75
- "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
76
- )
77
- with FileLock(".lock") as lock:
78
- nltk.download("punkt", quiet=True)
79
 
80
 
81
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
@@ -87,7 +71,7 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
87
  OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
88
  OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
89
  BOS_TOKEN_ID = 16384
90
- BASE_MODEL = 'facebook/bart-large-cnn' # we currently have issues with bart-large
91
 
92
 
93
  @dataclass
@@ -105,20 +89,34 @@ class ModelArguments:
105
  )
106
  model_type: Optional[str] = field(
107
  default=None,
108
- metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
 
 
 
109
  )
110
  config_name: Optional[str] = field(
111
- default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
 
 
 
112
  )
113
  tokenizer_name: Optional[str] = field(
114
- default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
 
 
 
115
  )
116
  cache_dir: Optional[str] = field(
117
- default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
 
 
 
118
  )
119
  use_fast_tokenizer: bool = field(
120
  default=True,
121
- metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
 
 
122
  )
123
  dtype: Optional[str] = field(
124
  default="float32",
@@ -140,28 +138,42 @@ class DataTrainingArguments:
140
  Arguments pertaining to what data we are going to input our model for training and eval.
141
  """
142
 
143
- dataset_name: Optional[str] = field(
144
- default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
145
- )
146
- dataset_config_name: Optional[str] = field(
147
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
148
- )
149
  text_column: Optional[str] = field(
150
- default='caption',
151
- metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
 
 
152
  )
153
  encoding_column: Optional[str] = field(
154
- default='encoding',
155
- metadata={"help": "The name of the column in the datasets containing the image encodings."},
 
 
 
 
 
 
 
 
 
156
  )
157
- train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
158
  validation_file: Optional[str] = field(
159
  default=None,
160
- metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
 
 
 
 
 
 
 
 
 
 
161
  )
162
- test_file: Optional[str] = field(
163
  default=None,
164
- metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
165
  )
166
  max_source_length: Optional[int] = field(
167
  default=128,
@@ -171,7 +183,8 @@ class DataTrainingArguments:
171
  },
172
  )
173
  no_decay: bool = field(
174
- default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
 
175
  )
176
  max_target_length: Optional[int] = field(
177
  default=OUTPUT_LENGTH,
@@ -203,62 +216,67 @@ class DataTrainingArguments:
203
  "value if set."
204
  },
205
  )
206
- max_predict_samples: Optional[int] = field(
207
- default=None,
208
- metadata={
209
- "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
210
- "value if set."
211
- },
212
  )
213
  preprocessing_num_workers: Optional[int] = field(
214
- default=80, # ensure we have the same datasets cached data and avoid using too much space
215
  metadata={"help": "The number of processes to use for the preprocessing."},
216
  )
217
  source_prefix: Optional[str] = field(
218
- default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
219
- )
220
- predict_with_generate: bool = field(
221
- default=False, metadata={"help": "Whether to use generate to calculate generative metrics."}
222
- )
223
- num_beams: Optional[int] = field(
224
  default=None,
225
  metadata={
226
- "help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
227
- "which is used during evaluation."
228
  },
229
  )
230
  overwrite_cache: bool = field(
231
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
 
232
  )
233
  log_interval: Optional[int] = field(
234
  default=40,
235
- metadata={
236
- "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
237
- "value if set."
238
- },
239
  )
240
  log_model: bool = field(
241
- default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
 
242
  )
243
  save_model_steps: Optional[int] = field(
244
- default=3000, # about once every hour in our experiments
245
  metadata={
246
  "help": "For logging the model more frequently. Used only when `log_model` is set."
247
  },
248
  )
249
 
250
  def __post_init__(self):
251
- if self.dataset_name is None and self.train_file is None and self.validation_file is None:
252
- raise ValueError("Need either a dataset name or a training/validation file.")
 
 
253
  else:
254
  if self.train_file is not None:
255
  extension = self.train_file.split(".")[-1]
256
- assert extension in ["tsv", "csv", "json"], "`train_file` should be a tsv, csv or json file."
 
 
 
 
 
257
  if self.validation_file is not None:
258
  extension = self.validation_file.split(".")[-1]
259
- assert extension in ["tsv", "csv", "json"], "`validation_file` should be a tsv, csv or json file."
 
 
 
 
 
260
  if self.val_max_target_length is None:
261
  self.val_max_target_length = self.max_target_length
 
 
 
 
262
 
263
 
264
  class TrainState(train_state.TrainState):
@@ -267,14 +285,20 @@ class TrainState(train_state.TrainState):
267
  optimizer_step: int
268
 
269
  def replicate(self):
270
- return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
 
 
271
 
272
 
273
  class CustomFlaxBartModule(FlaxBartModule):
274
  def setup(self):
275
  # check config is valid, otherwise set default values
276
- self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
277
- self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
 
 
 
 
278
 
279
  # we keep shared to easily load pre-trained weights
280
  self.shared = nn.Embed(
@@ -290,18 +314,29 @@ class CustomFlaxBartModule(FlaxBartModule):
290
  embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
291
  dtype=self.dtype,
292
  )
293
- self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
 
 
294
 
295
  # the decoder has a different config
296
  decoder_config = BartConfig(self.config.to_dict())
297
- decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
 
 
298
  decoder_config.vocab_size = self.config.vocab_size_output
299
- self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
 
 
300
 
301
- class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
 
 
 
302
  def setup(self):
303
  # check config is valid, otherwise set default values
304
- self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
 
 
305
 
306
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
307
  self.lm_head = nn.Dense(
@@ -310,13 +345,18 @@ class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerat
310
  dtype=self.dtype,
311
  kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
312
  )
313
- self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
 
 
 
314
 
315
  class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
316
  module_class = CustomFlaxBartForConditionalGenerationModule
317
-
318
 
319
- def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
 
 
 
320
  """
321
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
322
  Shuffle batches if `shuffle` is `True`.
@@ -334,33 +374,58 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
334
  for idx in batch_idx:
335
  batch = dataset[idx]
336
  batch = {k: jnp.array(v) for k, v in batch.items()}
337
-
338
  batch = shard(batch)
339
-
340
  yield batch
341
 
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  def create_learning_rate_fn(
344
- train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 
 
 
 
 
345
  ) -> Callable[[int], jnp.array]:
346
  """Returns a linear warmup, linear_decay learning rate function."""
347
  steps_per_epoch = train_ds_size // train_batch_size
348
  num_train_steps = steps_per_epoch * num_train_epochs
349
- warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
 
 
350
  if no_decay:
351
  return warmup_fn
352
  decay_fn = optax.linear_schedule(
353
- init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
 
 
 
 
 
354
  )
355
- schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
356
  return schedule_fn
357
 
358
 
359
  def wandb_log(metrics, step=None, prefix=None):
360
  if jax.process_index() == 0:
361
- log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
 
 
 
362
  if step is not None:
363
- log_metrics['train/step'] = step
364
  wandb.log(log_metrics)
365
 
366
 
@@ -369,11 +434,15 @@ def main():
369
  # or by passing the --help flag to this script.
370
  # We now keep distinct sets of args, for a cleaner separation of concerns.
371
 
372
- parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
 
 
373
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
374
  # If we pass only one argument to the script and it's the path to a json file,
375
  # let's parse it to get our arguments.
376
- model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
 
 
377
  else:
378
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
379
 
@@ -387,18 +456,6 @@ def main():
387
  f"Output directory ({training_args.output_dir}) already exists and is not empty."
388
  "Use --overwrite_output_dir to overcome."
389
  )
390
-
391
- # Set up wandb run
392
- wandb.init(
393
- entity='wandb',
394
- project='hf-flax-dalle-mini',
395
- job_type='Seq2SeqVQGAN',
396
- config=parser.parse_args()
397
- )
398
-
399
- # set default x-axis as 'train/step'
400
- wandb.define_metric('train/step')
401
- wandb.define_metric('*', step_metric='train/step')
402
 
403
  # Make one log on every process with the configuration for debugging.
404
  pylogging.basicConfig(
@@ -422,16 +479,15 @@ def main():
422
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
423
  # (the dataset will be downloaded automatically from the datasets Hub).
424
  #
425
- data_files = {}
426
- if data_args.train_file is not None:
427
- data_files["train"] = data_args.train_file
428
- if data_args.validation_file is not None:
429
- data_files["validation"] = data_args.validation_file
430
- if data_args.test_file is not None:
431
- data_files["test"] = data_args.test_file
432
- dataset = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir, delimiter="\t")
433
- # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
434
- # https://huggingface.co/docs/datasets/loading_datasets.html.
435
 
436
  # Set up items to load or create
437
  tokenizer = None
@@ -439,18 +495,29 @@ def main():
439
 
440
  def restore_state(state, artifact_dir):
441
  # restore optimizer state
442
- if (Path(artifact_dir) / 'opt_state.msgpack').exists():
443
- with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
444
- opt_state = from_bytes(state.opt_state, f.read())
445
-
446
  # restore steps
447
- if (Path(artifact_dir) / 'training_state.json').exists():
448
- with (Path(artifact_dir) / 'training_state.json').open('r') as f:
449
- training_state = json.load(f)
450
- step = training_state['step']
451
- optimizer_step = step // training_args.gradient_accumulation_steps
452
- state.replace(step=step, optimizer_step=optimizer_step)
453
-
 
 
 
 
 
 
 
 
 
 
 
 
454
  if model_args.from_checkpoint is not None:
455
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
456
  artifact_dir = artifact.download()
@@ -466,40 +533,54 @@ def main():
466
  config = model.config
467
 
468
  # load tokenizer if present
469
- if (Path(artifact_dir) / 'tokenizer_config.json').exists():
470
  tokenizer = AutoTokenizer.from_pretrained(
471
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
472
- )
 
 
473
 
474
  else:
475
  base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
476
- model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
 
 
477
  )
478
  # Set up our new model config
479
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
480
  config.tie_word_embeddings = False
481
  config.decoder_start_token_id = BOS_TOKEN_ID # for first token
482
- config.bos_token_id = BOS_TOKEN_ID # should not be used (due to forced_bos_token_id)
483
- config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
 
 
 
 
484
  config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
485
  config.forced_bos_token_id = None # we don't need this token
486
  config.forced_eos_token_id = None # we don't need this token
487
- config.force_bos_token_to_be_generated = False # otherwise it sets bos_token_id at loading
 
 
488
  config.min_length = data_args.max_target_length
489
  config.max_length = data_args.max_target_length
490
 
491
  # Create a custom model and initialize it randomly
492
- model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
 
 
493
 
494
  # Use pre-trained weights for encoder
495
- model.params['model']['encoder'] = base_model.params['model']['encoder']
496
- model.params['model']['shared'] = base_model.params['model']['shared']
497
  del base_model
498
 
499
  # Load tokenizer if it has not been set
500
  if tokenizer is None:
501
  tokenizer = AutoTokenizer.from_pretrained(
502
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
503
  )
504
 
505
  print(f"TPUs: {jax.device_count()}")
@@ -509,23 +590,11 @@ def main():
509
 
510
  # Preprocessing the datasets.
511
  # We need to tokenize inputs and targets.
512
- if training_args.do_train:
513
- column_names = dataset["train"].column_names
514
- elif training_args.do_eval:
515
- column_names = dataset["validation"].column_names
516
- elif training_args.do_predict:
517
- column_names = dataset["test"].column_names
518
- else:
519
- logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
520
- return
521
 
522
  # Get the column names for input/target.
523
  text_column = data_args.text_column
524
  encoding_column = data_args.encoding_column
525
 
526
- # Temporarily set max_target_length for training.
527
- max_target_length = data_args.max_target_length
528
-
529
  def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
530
  """
531
  Shift input ids one token to the right.
@@ -535,18 +604,28 @@ def main():
535
  shifted_input_ids[:, 0] = decoder_start_token_id
536
  return shifted_input_ids
537
 
 
 
 
 
 
 
538
  def preprocess_function(examples):
539
  inputs = examples[text_column]
540
- inputs = [prefix + inp for inp in inputs]
541
- # Setting padding="max_length" as we need fixed length inputs for jitted functions
542
  model_inputs = tokenizer(
543
- inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
 
 
 
 
544
  )
545
 
546
  # set up targets
547
  # Note: labels correspond to our target indices
548
  # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
549
- labels = [eval(indices) for indices in examples['encoding']]
550
  labels = np.asarray(labels)
551
 
552
  # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
@@ -563,46 +642,75 @@ def main():
563
  raise ValueError("--do_train requires a train dataset")
564
  train_dataset = dataset["train"]
565
  if data_args.max_train_samples is not None:
566
- train_dataset = train_dataset.select(range(data_args.max_train_samples))
567
- train_dataset = train_dataset.map(
568
- preprocess_function,
569
- batched=True,
570
- num_proc=data_args.preprocessing_num_workers,
571
- remove_columns=column_names,
572
- load_from_cache_file=not data_args.overwrite_cache,
573
- desc="Running tokenizer on train dataset",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
  )
575
 
576
  if training_args.do_eval:
577
- max_target_length = data_args.val_max_target_length
578
  if "validation" not in dataset:
579
  raise ValueError("--do_eval requires a validation dataset")
580
  eval_dataset = dataset["validation"]
581
  if data_args.max_eval_samples is not None:
582
- eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
583
- eval_dataset = eval_dataset.map(
584
- preprocess_function,
585
- batched=True,
586
- num_proc=data_args.preprocessing_num_workers,
587
- remove_columns=column_names,
588
- load_from_cache_file=not data_args.overwrite_cache,
589
- desc="Running tokenizer on validation dataset",
590
- )
591
-
592
- if training_args.do_predict:
593
- max_target_length = data_args.val_max_target_length
594
- if "test" not in dataset:
595
- raise ValueError("--do_predict requires a test dataset")
596
- predict_dataset = dataset["test"]
597
- if data_args.max_predict_samples is not None:
598
- predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
599
- predict_dataset = predict_dataset.map(
600
- preprocess_function,
601
- batched=True,
602
- num_proc=data_args.preprocessing_num_workers,
603
- remove_columns=column_names,
604
- load_from_cache_file=not data_args.overwrite_cache,
605
- desc="Running tokenizer on prediction dataset",
 
 
 
 
 
 
606
  )
607
 
608
  # Initialize our training
@@ -611,21 +719,40 @@ def main():
611
 
612
  # Store some constant
613
  num_epochs = int(training_args.num_train_epochs)
614
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
 
 
615
  total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
616
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
617
- steps_per_epoch = len(train_dataset) // train_batch_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
  total_steps = steps_per_epoch * num_epochs
619
- total_optimization_steps = (len(train_dataset) // total_batch_size) * num_epochs
620
 
621
  # Create learning rate schedule
622
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
623
- len(train_dataset),
624
  total_batch_size,
625
  training_args.num_train_epochs,
626
  training_args.warmup_steps,
627
  training_args.learning_rate,
628
- data_args.no_decay
629
  )
630
 
631
  # We use Optax's "masking" functionality to not apply weight decay
@@ -638,9 +765,17 @@ def main():
638
  def decay_mask_fn(params):
639
  flat_params = traverse_util.flatten_dict(params)
640
  layer_norm_params = [
641
- (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
 
 
 
 
 
642
  ]
643
- flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
 
 
 
644
  return traverse_util.unflatten_dict(flat_mask)
645
 
646
  # create adam optimizer
@@ -671,7 +806,10 @@ def main():
671
  )
672
  if model_args.from_checkpoint is not None:
673
  # restore optimizer state, step and optimizer_step
674
- restore_state(state, artifact_dir)
 
 
 
675
 
676
  # label smoothed cross entropy
677
  def loss_fn(logits, labels):
@@ -685,7 +823,9 @@ def main():
685
 
686
  def compute_loss(params):
687
  labels = batch.pop("labels")
688
- logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
 
689
  loss = loss_fn(logits, labels)
690
  return loss
691
 
@@ -694,10 +834,14 @@ def main():
694
  grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
695
 
696
  def update_fn():
697
- grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
 
 
698
  grads = jax.lax.pmean(grads, "batch")
699
  new_state = state.apply_gradients(
700
- grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
 
 
701
  )
702
  return new_state
703
 
@@ -708,7 +852,10 @@ def main():
708
  None,
709
  )
710
 
711
- metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step)}
 
 
 
712
  metrics = jax.lax.pmean(metrics, axis_name="batch")
713
 
714
  return new_state.replace(dropout_rng=new_dropout_rng), metrics
@@ -724,39 +871,25 @@ def main():
724
  metrics = jax.lax.pmean(metrics, axis_name="batch")
725
  return metrics
726
 
727
- # Define generation function
728
- max_length = (
729
- data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
730
- )
731
- num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
732
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
733
-
734
- def generate_step(params, batch):
735
- model.params = params
736
- output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
737
- return output_ids.sequences
738
-
739
  # Create parallel version of the train and eval step
740
- p_train_step = jax.pmap(
741
- train_step, "batch", donate_argnums=(0,)
742
- )
743
  p_eval_step = jax.pmap(eval_step, "batch")
744
- p_generate_step = jax.pmap(generate_step, "batch")
745
 
746
  # Replicate the train state on each device
747
  state = state.replicate()
748
 
749
  logger.info("***** Running training *****")
750
- logger.info(f" Num examples = {len(train_dataset)}")
751
  logger.info(f" Num Epochs = {num_epochs}")
752
- logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
 
 
753
  logger.info(
754
  f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
755
  )
756
  logger.info(f" Total global steps = {total_steps}")
757
  logger.info(f" Total optimization steps = {total_optimization_steps}")
758
 
759
- train_time = 0
760
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
761
  global_step = 0
762
 
@@ -764,31 +897,28 @@ def main():
764
  # ======================== Evaluating ==============================
765
  eval_metrics = []
766
  if training_args.do_eval:
767
- eval_preds = []
768
- eval_labels = []
769
-
770
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
771
- eval_steps = len(eval_dataset) // eval_batch_size
772
- for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
 
 
 
 
 
 
773
  # Model forward
774
- batch = next(eval_loader)
775
- labels = batch["labels"]
776
-
777
  metrics = p_eval_step(state.params, batch)
778
  eval_metrics.append(metrics)
779
 
780
- # generation
781
- if data_args.predict_with_generate:
782
- generated_ids = p_generate_step(state.params, batch)
783
- eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
784
- eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
785
-
786
  # normalize eval metrics
787
  eval_metrics = get_metrics(eval_metrics)
788
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
789
 
790
  # log metrics
791
- wandb_log(eval_metrics, step=global_step, prefix='eval')
792
 
793
  # Print metrics and update progress bar
794
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -812,30 +942,48 @@ def main():
812
 
813
  # save state
814
  state = unreplicate(state)
815
- with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
816
  f.write(to_bytes(state.opt_state))
817
- with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
818
- json.dump({'step': state.step.item()}, f)
 
 
819
 
820
  # save to W&B
821
  if data_args.log_model:
822
- metadata = {'step': step, 'epoch': epoch}
823
  if eval_metrics is not None:
824
- metadata['eval/loss'] = eval_metrics['loss']
825
  artifact = wandb.Artifact(
826
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
827
  )
828
- artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
829
- artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
830
- artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer.json'))
831
- artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
832
- artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
833
- artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
834
- artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
835
- artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
836
- artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
 
 
 
 
 
 
 
 
 
 
 
 
837
  wandb.run.log_artifact(artifact)
838
 
 
 
 
 
839
  # save to the hub
840
  if training_args.push_to_hub:
841
  model.save_pretrained(
@@ -843,39 +991,48 @@ def main():
843
  params=params,
844
  push_to_hub=training_args.push_to_hub,
845
  commit_message=f"Saving weights and logs of epoch {epoch+1}",
846
- temp_dir=True # avoid issues with being in a repository
847
  )
848
-
849
  for epoch in epochs:
850
  # ======================== Training ================================
851
- train_start = time.time()
852
 
853
  # Create sampling rng
854
  rng, input_rng = jax.random.split(rng)
855
 
856
  # Generate an epoch by shuffling sampling indices from the train dataset
857
- train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
858
- steps_per_epoch = len(train_dataset) // train_batch_size
 
 
 
 
 
859
  # train
860
- for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
861
- global_step +=1
862
- batch = next(train_loader)
 
 
 
 
 
863
  state, train_metric = p_train_step(state, batch)
864
 
865
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
866
  # log metrics
867
- wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
868
 
869
- if global_step % training_args.eval_steps == 0:
870
  run_evaluation()
871
-
872
  if global_step % data_args.save_model_steps == 0:
873
  run_save_model(state, global_step, epoch)
874
-
875
  # log final train metrics
876
- wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
877
 
878
- train_time += time.time() - train_start
879
  train_metric = unreplicate(train_metric)
880
  epochs.write(
881
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
@@ -888,38 +1045,5 @@ def main():
888
  run_save_model(state, global_step, epoch, eval_metrics)
889
 
890
 
891
- # ======================== Prediction loop ==============================
892
- if training_args.do_predict:
893
- logger.info("*** Predict ***")
894
-
895
- pred_metrics = []
896
- pred_generations = []
897
- pred_labels = []
898
-
899
- pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
900
- pred_steps = len(predict_dataset) // eval_batch_size
901
- for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
902
- # Model forward
903
- batch = next(pred_loader)
904
- labels = batch["labels"]
905
-
906
- metrics = p_eval_step(state.params, batch)
907
- pred_metrics.append(metrics)
908
-
909
- # generation
910
- if data_args.predict_with_generate:
911
- generated_ids = p_generate_step(state.params, batch)
912
- pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
913
- pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
914
-
915
- # normalize prediction metrics
916
- pred_metrics = get_metrics(pred_metrics)
917
- pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
918
-
919
- # Print metrics
920
- desc = f"Predict Loss: {pred_metrics['loss']})"
921
- logger.info(desc)
922
-
923
-
924
  if __name__ == "__main__":
925
  main()
 
20
  # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
21
 
22
  import os
23
+ import logging as pylogging # To avoid collision with transformers.utils.logging
 
 
 
 
24
  import sys
 
25
  from dataclasses import dataclass, field
26
  from functools import partial
27
  from pathlib import Path
 
29
  import json
30
 
31
  import datasets
 
32
  import numpy as np
33
  from datasets import Dataset, load_dataset, load_metric
34
  from tqdm import tqdm
 
45
  from flax.training import train_state
46
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
47
  from transformers import (
 
48
  FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
 
49
  AutoTokenizer,
50
  FlaxAutoModelForSeq2SeqLM,
51
  FlaxBartForConditionalGeneration,
 
57
 
58
  import wandb
59
 
60
+ from dalle_mini.text import TextNormalizer
61
 
62
+ logger = pylogging.getLogger(__name__)
 
 
 
 
 
 
 
 
63
 
64
 
65
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
 
71
  OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
72
  OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
73
  BOS_TOKEN_ID = 16384
74
+ BASE_MODEL = "facebook/bart-large-cnn" # we currently have issues with bart-large
75
 
76
 
77
  @dataclass
 
89
  )
90
  model_type: Optional[str] = field(
91
  default=None,
92
+ metadata={
93
+ "help": "If training from scratch, pass a model type from the list: "
94
+ + ", ".join(MODEL_TYPES)
95
+ },
96
  )
97
  config_name: Optional[str] = field(
98
+ default=None,
99
+ metadata={
100
+ "help": "Pretrained config name or path if not the same as model_name"
101
+ },
102
  )
103
  tokenizer_name: Optional[str] = field(
104
+ default=None,
105
+ metadata={
106
+ "help": "Pretrained tokenizer name or path if not the same as model_name"
107
+ },
108
  )
109
  cache_dir: Optional[str] = field(
110
+ default=None,
111
+ metadata={
112
+ "help": "Where do you want to store the pretrained models downloaded from s3"
113
+ },
114
  )
115
  use_fast_tokenizer: bool = field(
116
  default=True,
117
+ metadata={
118
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
119
+ },
120
  )
121
  dtype: Optional[str] = field(
122
  default="float32",
 
138
  Arguments pertaining to what data we are going to input our model for training and eval.
139
  """
140
 
 
 
 
 
 
 
141
  text_column: Optional[str] = field(
142
+ default="caption",
143
+ metadata={
144
+ "help": "The name of the column in the datasets containing the full texts (for summarization)."
145
+ },
146
  )
147
  encoding_column: Optional[str] = field(
148
+ default="encoding",
149
+ metadata={
150
+ "help": "The name of the column in the datasets containing the image encodings."
151
+ },
152
+ )
153
+ dataset_repo_or_path: Optional[str] = field(
154
+ default=None,
155
+ metadata={"help": "The dataset repository containing encoded files."},
156
+ )
157
+ train_file: Optional[str] = field(
158
+ default=None, metadata={"help": "The input training data file (a text file)."}
159
  )
 
160
  validation_file: Optional[str] = field(
161
  default=None,
162
+ metadata={
163
+ "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
164
+ },
165
+ )
166
+ streaming: bool = field(
167
+ default=False,
168
+ metadata={"help": "Whether to stream the dataset."},
169
+ )
170
+ len_train: Optional[int] = field(
171
+ default=None,
172
+ metadata={"help": "Length of training dataset, required for streaming"},
173
  )
174
+ len_eval: Optional[int] = field(
175
  default=None,
176
+ metadata={"help": "Length of validation dataset, required for streaming"},
177
  )
178
  max_source_length: Optional[int] = field(
179
  default=128,
 
183
  },
184
  )
185
  no_decay: bool = field(
186
+ default=False,
187
+ metadata={"help": "Whether to use decay in the learning rate scheduler."},
188
  )
189
  max_target_length: Optional[int] = field(
190
  default=OUTPUT_LENGTH,
 
216
  "value if set."
217
  },
218
  )
219
+ normalize_text: bool = field(
220
+ default=False,
221
+ metadata={"help": "Normalize/Simplify text"},
 
 
 
222
  )
223
  preprocessing_num_workers: Optional[int] = field(
224
+ default=80, # ensure we have the same datasets cached data and avoid using too much space
225
  metadata={"help": "The number of processes to use for the preprocessing."},
226
  )
227
  source_prefix: Optional[str] = field(
 
 
 
 
 
 
228
  default=None,
229
  metadata={
230
+ "help": "A prefix to add before every source text (useful for T5 models)."
 
231
  },
232
  )
233
  overwrite_cache: bool = field(
234
+ default=False,
235
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
236
  )
237
  log_interval: Optional[int] = field(
238
  default=40,
239
+ metadata={"help": "Log frequency for metrics"},
 
 
 
240
  )
241
  log_model: bool = field(
242
+ default=False,
243
+ metadata={"help": "Overwrite the cached training and evaluation sets"},
244
  )
245
  save_model_steps: Optional[int] = field(
246
+ default=5000, # about once every 1.5h in our experiments
247
  metadata={
248
  "help": "For logging the model more frequently. Used only when `log_model` is set."
249
  },
250
  )
251
 
252
  def __post_init__(self):
253
+ if self.dataset_repo_or_path is None:
254
+ raise ValueError("Need a dataset repository or path.")
255
+ if self.train_file is None or self.validation_file is None:
256
+ raise ValueError("Need training/validation file.")
257
  else:
258
  if self.train_file is not None:
259
  extension = self.train_file.split(".")[-1]
260
+ assert extension in [
261
+ "tsv",
262
+ "csv",
263
+ "json",
264
+ "jsonl",
265
+ ], "`train_file` should be a tsv, csv or json file."
266
  if self.validation_file is not None:
267
  extension = self.validation_file.split(".")[-1]
268
+ assert extension in [
269
+ "tsv",
270
+ "csv",
271
+ "json",
272
+ "jsonl",
273
+ ], "`validation_file` should be a tsv, csv or json file."
274
  if self.val_max_target_length is None:
275
  self.val_max_target_length = self.max_target_length
276
+ if self.streaming and (self.len_train is None or self.len_eval is None):
277
+ raise ValueError(
278
+ "Streaming requires providing length of training and validation datasets"
279
+ )
280
 
281
 
282
  class TrainState(train_state.TrainState):
 
285
  optimizer_step: int
286
 
287
  def replicate(self):
288
+ return jax_utils.replicate(self).replace(
289
+ dropout_rng=shard_prng_key(self.dropout_rng)
290
+ )
291
 
292
 
293
  class CustomFlaxBartModule(FlaxBartModule):
294
  def setup(self):
295
  # check config is valid, otherwise set default values
296
+ self.config.vocab_size_output = getattr(
297
+ self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
298
+ )
299
+ self.config.max_position_embeddings_decoder = getattr(
300
+ self.config, "max_position_embeddings_decoder", OUTPUT_LENGTH
301
+ )
302
 
303
  # we keep shared to easily load pre-trained weights
304
  self.shared = nn.Embed(
 
314
  embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
315
  dtype=self.dtype,
316
  )
317
+ self.encoder = FlaxBartEncoder(
318
+ self.config, dtype=self.dtype, embed_tokens=self.shared
319
+ )
320
 
321
  # the decoder has a different config
322
  decoder_config = BartConfig(self.config.to_dict())
323
+ decoder_config.max_position_embeddings = (
324
+ self.config.max_position_embeddings_decoder
325
+ )
326
  decoder_config.vocab_size = self.config.vocab_size_output
327
+ self.decoder = FlaxBartDecoder(
328
+ decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
329
+ )
330
 
331
+
332
+ class CustomFlaxBartForConditionalGenerationModule(
333
+ FlaxBartForConditionalGenerationModule
334
+ ):
335
  def setup(self):
336
  # check config is valid, otherwise set default values
337
+ self.config.vocab_size_output = getattr(
338
+ self.config, "vocab_size_output", OUTPUT_VOCAB_SIZE
339
+ )
340
 
341
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
342
  self.lm_head = nn.Dense(
 
345
  dtype=self.dtype,
346
  kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
347
  )
348
+ self.final_logits_bias = self.param(
349
+ "final_logits_bias", self.bias_init, (1, self.config.vocab_size_output)
350
+ )
351
+
352
 
353
  class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
354
  module_class = CustomFlaxBartForConditionalGenerationModule
 
355
 
356
+
357
+ def data_loader(
358
+ rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
359
+ ):
360
  """
361
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
362
  Shuffle batches if `shuffle` is `True`.
 
374
  for idx in batch_idx:
375
  batch = dataset[idx]
376
  batch = {k: jnp.array(v) for k, v in batch.items()}
 
377
  batch = shard(batch)
 
378
  yield batch
379
 
380
 
381
+ def data_loader_streaming(dataset: Dataset, batch_size: int):
382
+ keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
383
+ batch = {k: [] for k in keys}
384
+ for item in dataset:
385
+ for k, v in item.items():
386
+ batch[k].append(v)
387
+ if len(batch[keys[0]]) == batch_size:
388
+ batch = {k: jnp.array(v) for k, v in batch.items()}
389
+ batch = shard(batch)
390
+ yield batch
391
+ batch = {k: [] for k in keys}
392
+
393
+
394
  def create_learning_rate_fn(
395
+ train_ds_size: int,
396
+ train_batch_size: int,
397
+ num_train_epochs: int,
398
+ num_warmup_steps: int,
399
+ learning_rate: float,
400
+ no_decay: bool,
401
  ) -> Callable[[int], jnp.array]:
402
  """Returns a linear warmup, linear_decay learning rate function."""
403
  steps_per_epoch = train_ds_size // train_batch_size
404
  num_train_steps = steps_per_epoch * num_train_epochs
405
+ warmup_fn = optax.linear_schedule(
406
+ init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
407
+ )
408
  if no_decay:
409
  return warmup_fn
410
  decay_fn = optax.linear_schedule(
411
+ init_value=learning_rate,
412
+ end_value=0,
413
+ transition_steps=num_train_steps - num_warmup_steps,
414
+ )
415
+ schedule_fn = optax.join_schedules(
416
+ schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
417
  )
 
418
  return schedule_fn
419
 
420
 
421
  def wandb_log(metrics, step=None, prefix=None):
422
  if jax.process_index() == 0:
423
+ log_metrics = {
424
+ f"{prefix}/{k}" if prefix is not None else k: jax.device_get(v)
425
+ for k, v in metrics.items()
426
+ }
427
  if step is not None:
428
+ log_metrics["train/step"] = step
429
  wandb.log(log_metrics)
430
 
431
 
 
434
  # or by passing the --help flag to this script.
435
  # We now keep distinct sets of args, for a cleaner separation of concerns.
436
 
437
+ parser = HfArgumentParser(
438
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
439
+ )
440
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
441
  # If we pass only one argument to the script and it's the path to a json file,
442
  # let's parse it to get our arguments.
443
+ model_args, data_args, training_args = parser.parse_json_file(
444
+ json_file=os.path.abspath(sys.argv[1])
445
+ )
446
  else:
447
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
448
 
 
456
  f"Output directory ({training_args.output_dir}) already exists and is not empty."
457
  "Use --overwrite_output_dir to overcome."
458
  )
 
 
 
 
 
 
 
 
 
 
 
 
459
 
460
  # Make one log on every process with the configuration for debugging.
461
  pylogging.basicConfig(
 
479
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
480
  # (the dataset will be downloaded automatically from the datasets Hub).
481
  #
482
+ data_files = {
483
+ "train": data_args.train_file,
484
+ "validation": data_args.validation_file,
485
+ }
486
+ dataset = load_dataset(
487
+ data_args.dataset_repo_or_path,
488
+ data_files=data_files,
489
+ streaming=data_args.streaming,
490
+ )
 
491
 
492
  # Set up items to load or create
493
  tokenizer = None
 
495
 
496
  def restore_state(state, artifact_dir):
497
  # restore optimizer state
498
+ with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
499
+ opt_state = from_bytes(state.opt_state, f.read())
500
+
 
501
  # restore steps
502
+ with (Path(artifact_dir) / "training_state.json").open("r") as f:
503
+ training_state = json.load(f)
504
+ step = training_state["step"]
505
+ optimizer_step = step // training_args.gradient_accumulation_steps
506
+
507
+ return step, optimizer_step, opt_state
508
+
509
+ # Set up wandb run
510
+ wandb.init(
511
+ entity="dalle-mini",
512
+ project="dalle-mini",
513
+ job_type="Seq2Seq",
514
+ config=parser.parse_args(),
515
+ )
516
+
517
+ # set default x-axis as 'train/step'
518
+ wandb.define_metric("train/step")
519
+ wandb.define_metric("*", step_metric="train/step")
520
+
521
  if model_args.from_checkpoint is not None:
522
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
523
  artifact_dir = artifact.download()
 
533
  config = model.config
534
 
535
  # load tokenizer if present
536
+ if (Path(artifact_dir) / "tokenizer_config.json").exists():
537
  tokenizer = AutoTokenizer.from_pretrained(
538
+ model_args.model_name_or_path,
539
+ cache_dir=model_args.cache_dir,
540
+ use_fast=model_args.use_fast_tokenizer,
541
+ )
542
 
543
  else:
544
  base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
545
+ model_args.model_name_or_path,
546
+ seed=training_args.seed,
547
+ dtype=getattr(jnp, model_args.dtype),
548
  )
549
  # Set up our new model config
550
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
551
  config.tie_word_embeddings = False
552
  config.decoder_start_token_id = BOS_TOKEN_ID # for first token
553
+ config.bos_token_id = (
554
+ BOS_TOKEN_ID # should not be used (due to forced_bos_token_id)
555
+ )
556
+ config.pos_token_id = (
557
+ BOS_TOKEN_ID # should not be needed (as we generate until max_length)
558
+ )
559
  config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
560
  config.forced_bos_token_id = None # we don't need this token
561
  config.forced_eos_token_id = None # we don't need this token
562
+ config.force_bos_token_to_be_generated = (
563
+ False # otherwise it sets bos_token_id at loading
564
+ )
565
  config.min_length = data_args.max_target_length
566
  config.max_length = data_args.max_target_length
567
 
568
  # Create a custom model and initialize it randomly
569
+ model = CustomFlaxBartForConditionalGeneration(
570
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
571
+ )
572
 
573
  # Use pre-trained weights for encoder
574
+ model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
575
+ model.params["model"]["shared"] = base_model.params["model"]["shared"]
576
  del base_model
577
 
578
  # Load tokenizer if it has not been set
579
  if tokenizer is None:
580
  tokenizer = AutoTokenizer.from_pretrained(
581
+ model_args.model_name_or_path,
582
+ cache_dir=model_args.cache_dir,
583
+ use_fast=model_args.use_fast_tokenizer,
584
  )
585
 
586
  print(f"TPUs: {jax.device_count()}")
 
590
 
591
  # Preprocessing the datasets.
592
  # We need to tokenize inputs and targets.
 
 
 
 
 
 
 
 
 
593
 
594
  # Get the column names for input/target.
595
  text_column = data_args.text_column
596
  encoding_column = data_args.encoding_column
597
 
 
 
 
598
  def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
599
  """
600
  Shift input ids one token to the right.
 
604
  shifted_input_ids[:, 0] = decoder_start_token_id
605
  return shifted_input_ids
606
 
607
+ text_normalizer = TextNormalizer() if data_args.normalize_text else None
608
+
609
+ def normalize_text(example):
610
+ example[text_column] = text_normalizer(example[text_column])
611
+ return example
612
+
613
  def preprocess_function(examples):
614
  inputs = examples[text_column]
615
+ inputs = [prefix + inp for inp in inputs] if prefix else inputs
616
+ # Setting padding="max_length" as we need fixed length inputs for jitted functions
617
  model_inputs = tokenizer(
618
+ inputs,
619
+ max_length=data_args.max_source_length,
620
+ padding="max_length",
621
+ truncation=True,
622
+ return_tensors="np",
623
  )
624
 
625
  # set up targets
626
  # Note: labels correspond to our target indices
627
  # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
628
+ labels = examples[encoding_column]
629
  labels = np.asarray(labels)
630
 
631
  # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
 
642
  raise ValueError("--do_train requires a train dataset")
643
  train_dataset = dataset["train"]
644
  if data_args.max_train_samples is not None:
645
+ train_dataset = (
646
+ train_dataset.take(data_args.max_train_samples)
647
+ if data_args.streaming
648
+ else train_dataset.select(range(data_args.max_train_samples))
649
+ )
650
+ if data_args.streaming:
651
+ train_dataset = train_dataset.shuffle(1000, training_args.seed)
652
+ if data_args.normalize_text:
653
+ train_dataset = (
654
+ train_dataset.map(normalize_text)
655
+ if data_args.streaming
656
+ else train_dataset.map(
657
+ normalize_text,
658
+ num_proc=data_args.preprocessing_num_workers,
659
+ load_from_cache_file=not data_args.overwrite_cache,
660
+ desc="Normalizing the validation dataset",
661
+ )
662
+ )
663
+ train_dataset = (
664
+ train_dataset.map(
665
+ preprocess_function,
666
+ batched=True,
667
+ )
668
+ if data_args.streaming
669
+ else train_dataset.map(
670
+ preprocess_function,
671
+ batched=True,
672
+ num_proc=data_args.preprocessing_num_workers,
673
+ remove_columns=train_dataset.column_names,
674
+ load_from_cache_file=not data_args.overwrite_cache,
675
+ desc="Running tokenizer on validation dataset",
676
+ )
677
  )
678
 
679
  if training_args.do_eval:
 
680
  if "validation" not in dataset:
681
  raise ValueError("--do_eval requires a validation dataset")
682
  eval_dataset = dataset["validation"]
683
  if data_args.max_eval_samples is not None:
684
+ eval_dataset = (
685
+ eval_dataset.take(data_args.max_train_samples)
686
+ if data_args.streaming
687
+ else eval_dataset.select(range(data_args.max_train_samples))
688
+ )
689
+ if data_args.normalize_text:
690
+ eval_dataset = (
691
+ eval_dataset.map(normalize_text)
692
+ if data_args.streaming
693
+ else eval_dataset.map(
694
+ normalize_text,
695
+ num_proc=data_args.preprocessing_num_workers,
696
+ load_from_cache_file=not data_args.overwrite_cache,
697
+ desc="Normalizing the validation dataset",
698
+ )
699
+ )
700
+ eval_dataset = (
701
+ eval_dataset.map(
702
+ preprocess_function,
703
+ batched=True,
704
+ )
705
+ if data_args.streaming
706
+ else eval_dataset.map(
707
+ preprocess_function,
708
+ batched=True,
709
+ num_proc=data_args.preprocessing_num_workers,
710
+ remove_columns=eval_dataset.column_names,
711
+ load_from_cache_file=not data_args.overwrite_cache,
712
+ desc="Running tokenizer on validation dataset",
713
+ )
714
  )
715
 
716
  # Initialize our training
 
719
 
720
  # Store some constant
721
  num_epochs = int(training_args.num_train_epochs)
722
+ train_batch_size = (
723
+ int(training_args.per_device_train_batch_size) * jax.device_count()
724
+ )
725
  total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
726
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
727
+ if data_args.streaming:
728
+ len_train_dataset = data_args.len_train
729
+ if (
730
+ data_args.max_train_samples is not None
731
+ and data_args.max_train_samples < len_train_dataset
732
+ ):
733
+ len_train_dataset = data_args.max_train_samples
734
+
735
+ len_eval_dataset = data_args.len_eval
736
+ if (
737
+ data_args.max_eval_samples is not None
738
+ and data_args.max_eval_samples < len_eval_dataset
739
+ ):
740
+ len_eval_dataset = data_args.max_eval_samples
741
+ else:
742
+ len_train_dataset = len(train_dataset)
743
+ len_eval_dataset = len(eval_dataset)
744
+ steps_per_epoch = len_train_dataset // train_batch_size
745
  total_steps = steps_per_epoch * num_epochs
746
+ total_optimization_steps = (len_train_dataset // total_batch_size) * num_epochs
747
 
748
  # Create learning rate schedule
749
  linear_decay_lr_schedule_fn = create_learning_rate_fn(
750
+ len_train_dataset,
751
  total_batch_size,
752
  training_args.num_train_epochs,
753
  training_args.warmup_steps,
754
  training_args.learning_rate,
755
+ data_args.no_decay,
756
  )
757
 
758
  # We use Optax's "masking" functionality to not apply weight decay
 
765
  def decay_mask_fn(params):
766
  flat_params = traverse_util.flatten_dict(params)
767
  layer_norm_params = [
768
+ (name, "scale")
769
+ for name in [
770
+ "self_attn_layer_norm",
771
+ "layernorm_embedding",
772
+ "final_layer_norm",
773
+ ]
774
  ]
775
+ flat_mask = {
776
+ path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
777
+ for path in flat_params
778
+ }
779
  return traverse_util.unflatten_dict(flat_mask)
780
 
781
  # create adam optimizer
 
806
  )
807
  if model_args.from_checkpoint is not None:
808
  # restore optimizer state, step and optimizer_step
809
+ step, optimizer_step, opt_state = restore_state(state, artifact_dir)
810
+ state = state.replace(
811
+ step=step, optimizer_step=optimizer_step, opt_state=opt_state
812
+ )
813
 
814
  # label smoothed cross entropy
815
  def loss_fn(logits, labels):
 
823
 
824
  def compute_loss(params):
825
  labels = batch.pop("labels")
826
+ logits = state.apply_fn(
827
+ **batch, params=params, dropout_rng=dropout_rng, train=True
828
+ )[0]
829
  loss = loss_fn(logits, labels)
830
  return loss
831
 
 
834
  grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
835
 
836
  def update_fn():
837
+ grads = jax.tree_map(
838
+ lambda x: x / training_args.gradient_accumulation_steps, grad_accum
839
+ )
840
  grads = jax.lax.pmean(grads, "batch")
841
  new_state = state.apply_gradients(
842
+ grads=grads,
843
+ grad_accum=jax.tree_map(jnp.zeros_like, grads),
844
+ optimizer_step=state.optimizer_step + 1,
845
  )
846
  return new_state
847
 
 
852
  None,
853
  )
854
 
855
+ metrics = {
856
+ "loss": loss,
857
+ "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step),
858
+ }
859
  metrics = jax.lax.pmean(metrics, axis_name="batch")
860
 
861
  return new_state.replace(dropout_rng=new_dropout_rng), metrics
 
871
  metrics = jax.lax.pmean(metrics, axis_name="batch")
872
  return metrics
873
 
 
 
 
 
 
 
 
 
 
 
 
 
874
  # Create parallel version of the train and eval step
875
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
 
 
876
  p_eval_step = jax.pmap(eval_step, "batch")
 
877
 
878
  # Replicate the train state on each device
879
  state = state.replicate()
880
 
881
  logger.info("***** Running training *****")
882
+ logger.info(f" Num examples = {len_train_dataset}")
883
  logger.info(f" Num Epochs = {num_epochs}")
884
+ logger.info(
885
+ f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
886
+ )
887
  logger.info(
888
  f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
889
  )
890
  logger.info(f" Total global steps = {total_steps}")
891
  logger.info(f" Total optimization steps = {total_optimization_steps}")
892
 
 
893
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
894
  global_step = 0
895
 
 
897
  # ======================== Evaluating ==============================
898
  eval_metrics = []
899
  if training_args.do_eval:
900
+ if data_args.streaming:
901
+ eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
902
+ else:
903
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
904
+ eval_steps = len_eval_dataset // eval_batch_size
905
+ for batch in tqdm(
906
+ eval_loader,
907
+ desc="Evaluating...",
908
+ position=2,
909
+ leave=False,
910
+ total=eval_steps,
911
+ ):
912
  # Model forward
 
 
 
913
  metrics = p_eval_step(state.params, batch)
914
  eval_metrics.append(metrics)
915
 
 
 
 
 
 
 
916
  # normalize eval metrics
917
  eval_metrics = get_metrics(eval_metrics)
918
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
919
 
920
  # log metrics
921
+ wandb_log(eval_metrics, step=global_step, prefix="eval")
922
 
923
  # Print metrics and update progress bar
924
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
942
 
943
  # save state
944
  state = unreplicate(state)
945
+ with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
946
  f.write(to_bytes(state.opt_state))
947
+ with (Path(training_args.output_dir) / "training_state.json").open(
948
+ "w"
949
+ ) as f:
950
+ json.dump({"step": state.step.item()}, f)
951
 
952
  # save to W&B
953
  if data_args.log_model:
954
+ metadata = {"step": step, "epoch": epoch}
955
  if eval_metrics is not None:
956
+ metadata["eval/loss"] = eval_metrics["loss"]
957
  artifact = wandb.Artifact(
958
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
959
  )
960
+ artifact.add_file(
961
+ str(Path(training_args.output_dir) / "flax_model.msgpack")
962
+ )
963
+ artifact.add_file(str(Path(training_args.output_dir) / "config.json"))
964
+ artifact.add_file(
965
+ str(Path(training_args.output_dir) / "tokenizer.json")
966
+ )
967
+ artifact.add_file(
968
+ str(Path(training_args.output_dir) / "tokenizer_config.json")
969
+ )
970
+ artifact.add_file(str(Path(training_args.output_dir) / "vocab.json"))
971
+ artifact.add_file(str(Path(training_args.output_dir) / "merges.txt"))
972
+ artifact.add_file(
973
+ str(Path(training_args.output_dir) / "special_tokens_map.json")
974
+ )
975
+ artifact.add_file(
976
+ str(Path(training_args.output_dir) / "opt_state.msgpack")
977
+ )
978
+ artifact.add_file(
979
+ str(Path(training_args.output_dir) / "training_state.json")
980
+ )
981
  wandb.run.log_artifact(artifact)
982
 
983
+ # save some space
984
+ c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
985
+ c.cleanup(wandb.util.from_human_size("5GB"))
986
+
987
  # save to the hub
988
  if training_args.push_to_hub:
989
  model.save_pretrained(
 
991
  params=params,
992
  push_to_hub=training_args.push_to_hub,
993
  commit_message=f"Saving weights and logs of epoch {epoch+1}",
994
+ temp_dir=True, # avoid issues with being in a repository
995
  )
996
+
997
  for epoch in epochs:
998
  # ======================== Training ================================
999
+ wandb_log({"train/epoch": epoch}, step=global_step)
1000
 
1001
  # Create sampling rng
1002
  rng, input_rng = jax.random.split(rng)
1003
 
1004
  # Generate an epoch by shuffling sampling indices from the train dataset
1005
+ if data_args.streaming:
1006
+ train_dataset.set_epoch(epoch)
1007
+ train_loader = data_loader_streaming(train_dataset, train_batch_size)
1008
+ else:
1009
+ train_loader = data_loader(
1010
+ input_rng, train_dataset, train_batch_size, shuffle=True
1011
+ )
1012
  # train
1013
+ for batch in tqdm(
1014
+ train_loader,
1015
+ desc="Training...",
1016
+ position=1,
1017
+ leave=False,
1018
+ total=steps_per_epoch,
1019
+ ):
1020
+ global_step += 1
1021
  state, train_metric = p_train_step(state, batch)
1022
 
1023
  if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
1024
  # log metrics
1025
+ wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
1026
 
1027
+ if training_args.eval_steps and global_step % training_args.eval_steps == 0:
1028
  run_evaluation()
1029
+
1030
  if global_step % data_args.save_model_steps == 0:
1031
  run_save_model(state, global_step, epoch)
1032
+
1033
  # log final train metrics
1034
+ wandb_log(unreplicate(train_metric), step=global_step, prefix="train")
1035
 
 
1036
  train_metric = unreplicate(train_metric)
1037
  epochs.write(
1038
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
 
1045
  run_save_model(state, global_step, epoch, eval_metrics)
1046
 
1047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1048
  if __name__ == "__main__":
1049
  main()