boris commited on
Commit
3cd6d41
1 Parent(s): a253eea

feat: cleanup training script

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +26 -65
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -23,21 +23,19 @@ import os
23
  import logging as pylogging # To avoid collision with transformers.utils.logging
24
  import sys
25
  from dataclasses import dataclass, field
26
- from functools import partial
27
  from pathlib import Path
28
  from typing import Callable, Optional
29
  import json
30
 
31
  import datasets
32
  import numpy as np
33
- from datasets import Dataset, load_dataset, load_metric
34
  from tqdm import tqdm
35
 
36
  import jax
37
  import jax.numpy as jnp
38
  import optax
39
  import transformers
40
- from filelock import FileLock
41
  from flax import jax_utils, traverse_util
42
  from flax.serialization import from_bytes, to_bytes
43
  import flax.linen as nn
@@ -45,15 +43,12 @@ from flax.jax_utils import unreplicate
45
  from flax.training import train_state
46
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
47
  from transformers import (
48
- FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
49
  AutoTokenizer,
50
- FlaxAutoModelForSeq2SeqLM,
51
  FlaxBartForConditionalGeneration,
52
  HfArgumentParser,
53
  TrainingArguments,
54
  )
55
  from transformers.models.bart.modeling_flax_bart import *
56
- from transformers.file_utils import is_offline_mode
57
 
58
  import wandb
59
 
@@ -62,10 +57,6 @@ from dalle_mini.text import TextNormalizer
62
  logger = pylogging.getLogger(__name__)
63
 
64
 
65
- MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
66
- MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
67
-
68
-
69
  # Model hyperparameters, for convenience
70
  # TODO: the model has now it's own definition file and should be imported
71
  OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
@@ -87,25 +78,12 @@ class ModelArguments:
87
  "Don't set if you want to train a model from scratch."
88
  },
89
  )
90
- model_type: Optional[str] = field(
91
- default=None,
92
- metadata={
93
- "help": "If training from scratch, pass a model type from the list: "
94
- + ", ".join(MODEL_TYPES)
95
- },
96
- )
97
  config_name: Optional[str] = field(
98
  default=None,
99
  metadata={
100
  "help": "Pretrained config name or path if not the same as model_name"
101
  },
102
  )
103
- cache_dir: Optional[str] = field(
104
- default=None,
105
- metadata={
106
- "help": "Where do you want to store the pretrained models downloaded from s3"
107
- },
108
- )
109
  use_fast_tokenizer: bool = field(
110
  default=True,
111
  metadata={
@@ -281,6 +259,19 @@ class TrainState(train_state.TrainState):
281
  dropout_rng=shard_prng_key(self.dropout_rng)
282
  )
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
  class CustomFlaxBartModule(FlaxBartModule):
286
  def setup(self):
@@ -480,22 +471,6 @@ def main():
480
  streaming=data_args.streaming,
481
  )
482
 
483
- # Set up items to load or create
484
- tokenizer = None
485
- artifact_dir = None
486
-
487
- def restore_state(state, artifact_dir):
488
- # restore optimizer state
489
- with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
490
- opt_state = from_bytes(state.opt_state, f.read())
491
-
492
- # restore steps
493
- with (Path(artifact_dir) / "training_state.json").open("r") as f:
494
- training_state = json.load(f)
495
- step = training_state["step"]
496
-
497
- return step, opt_state
498
-
499
  # Set up wandb run
500
  wandb.init(
501
  entity="dalle-mini",
@@ -510,22 +485,11 @@ def main():
510
  artifact_dir = artifact.download()
511
  model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
512
 
513
- # some models will try to change bos (because of force_bos_token_to_be_generated)
514
- # we ensure bos and eos are not forced
515
- model.config.force_bos_token_to_be_generated = False
516
- model.config.forced_bos_token_id = None
517
- model.config.forced_eos_token_id = None
518
-
519
- # used in the preprocessing function
520
- config = model.config
521
-
522
- # load tokenizer if present
523
- if (Path(artifact_dir) / "tokenizer_config.json").exists():
524
- tokenizer = AutoTokenizer.from_pretrained(
525
- model_args.model_name_or_path,
526
- cache_dir=model_args.cache_dir,
527
- use_fast=model_args.use_fast_tokenizer,
528
- )
529
 
530
  else:
531
  # Set up our new model config
@@ -552,11 +516,9 @@ def main():
552
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
553
  )
554
 
555
- # Load tokenizer if it has not been set
556
- if tokenizer is None:
557
  tokenizer = AutoTokenizer.from_pretrained(
558
  model_args.model_name_or_path,
559
- cache_dir=model_args.cache_dir,
560
  use_fast=model_args.use_fast_tokenizer,
561
  )
562
 
@@ -609,7 +571,9 @@ def main():
609
  model_inputs["labels"] = labels
610
 
611
  # In our case, this prepends the bos token and removes the last one
612
- decoder_input_ids = shift_tokens_right(labels, config.decoder_start_token_id)
 
 
613
  model_inputs["decoder_input_ids"] = decoder_input_ids
614
 
615
  return model_inputs
@@ -787,8 +751,7 @@ def main():
787
  )
788
  if model_args.from_checkpoint is not None:
789
  # restore optimizer state and step
790
- step, opt_state = restore_state(state, artifact_dir)
791
- state = state.replace(step=step, opt_state=opt_state)
792
  # TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
793
 
794
  # label smoothed cross entropy
@@ -974,16 +937,14 @@ def main():
974
  for epoch in epochs:
975
  # ======================== Training ================================
976
  step = unreplicate(state.step)
977
- wandb_log({"train/epoch": epoch}, step=step)
978
-
979
- # Create sampling rng
980
- rng, input_rng = jax.random.split(rng)
981
 
982
  # Generate an epoch by shuffling sampling indices from the train dataset
983
  if data_args.streaming:
984
  train_dataset.set_epoch(epoch)
985
  train_loader = data_loader_streaming(train_dataset, train_batch_size)
986
  else:
 
987
  train_loader = data_loader(
988
  input_rng, train_dataset, train_batch_size, shuffle=True
989
  )
 
23
  import logging as pylogging # To avoid collision with transformers.utils.logging
24
  import sys
25
  from dataclasses import dataclass, field
 
26
  from pathlib import Path
27
  from typing import Callable, Optional
28
  import json
29
 
30
  import datasets
31
  import numpy as np
32
+ from datasets import Dataset, load_dataset
33
  from tqdm import tqdm
34
 
35
  import jax
36
  import jax.numpy as jnp
37
  import optax
38
  import transformers
 
39
  from flax import jax_utils, traverse_util
40
  from flax.serialization import from_bytes, to_bytes
41
  import flax.linen as nn
 
43
  from flax.training import train_state
44
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
45
  from transformers import (
 
46
  AutoTokenizer,
 
47
  FlaxBartForConditionalGeneration,
48
  HfArgumentParser,
49
  TrainingArguments,
50
  )
51
  from transformers.models.bart.modeling_flax_bart import *
 
52
 
53
  import wandb
54
 
 
57
  logger = pylogging.getLogger(__name__)
58
 
59
 
 
 
 
 
60
  # Model hyperparameters, for convenience
61
  # TODO: the model has now it's own definition file and should be imported
62
  OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
 
78
  "Don't set if you want to train a model from scratch."
79
  },
80
  )
 
 
 
 
 
 
 
81
  config_name: Optional[str] = field(
82
  default=None,
83
  metadata={
84
  "help": "Pretrained config name or path if not the same as model_name"
85
  },
86
  )
 
 
 
 
 
 
87
  use_fast_tokenizer: bool = field(
88
  default=True,
89
  metadata={
 
259
  dropout_rng=shard_prng_key(self.dropout_rng)
260
  )
261
 
262
+ def restore_state(self, artifact_dir):
263
+ # restore optimizer state
264
+ with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
265
+ opt_state = from_bytes(self.opt_state, f.read())
266
+
267
+ # restore steps
268
+ with (Path(artifact_dir) / "training_state.json").open("r") as f:
269
+ training_state = json.load(f)
270
+ step = training_state["step"]
271
+
272
+ # replace state
273
+ return self.replace(step=step, opt_state=opt_state)
274
+
275
 
276
  class CustomFlaxBartModule(FlaxBartModule):
277
  def setup(self):
 
471
  streaming=data_args.streaming,
472
  )
473
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  # Set up wandb run
475
  wandb.init(
476
  entity="dalle-mini",
 
485
  artifact_dir = artifact.download()
486
  model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
487
 
488
+ # load tokenizer
489
+ tokenizer = AutoTokenizer.from_pretrained(
490
+ artifact_dir,
491
+ use_fast=model_args.use_fast_tokenizer,
492
+ )
 
 
 
 
 
 
 
 
 
 
 
493
 
494
  else:
495
  # Set up our new model config
 
516
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
517
  )
518
 
519
+ # Load tokenizer
 
520
  tokenizer = AutoTokenizer.from_pretrained(
521
  model_args.model_name_or_path,
 
522
  use_fast=model_args.use_fast_tokenizer,
523
  )
524
 
 
571
  model_inputs["labels"] = labels
572
 
573
  # In our case, this prepends the bos token and removes the last one
574
+ decoder_input_ids = shift_tokens_right(
575
+ labels, model.config.decoder_start_token_id
576
+ )
577
  model_inputs["decoder_input_ids"] = decoder_input_ids
578
 
579
  return model_inputs
 
751
  )
752
  if model_args.from_checkpoint is not None:
753
  # restore optimizer state and step
754
+ state = state.restore_state(artifact_dir)
 
755
  # TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
756
 
757
  # label smoothed cross entropy
 
937
  for epoch in epochs:
938
  # ======================== Training ================================
939
  step = unreplicate(state.step)
940
+ # wandb_log({"train/epoch": epoch}, step=step)
 
 
 
941
 
942
  # Generate an epoch by shuffling sampling indices from the train dataset
943
  if data_args.streaming:
944
  train_dataset.set_epoch(epoch)
945
  train_loader = data_loader_streaming(train_dataset, train_batch_size)
946
  else:
947
+ rng, input_rng = jax.random.split(rng)
948
  train_loader = data_loader(
949
  input_rng, train_dataset, train_batch_size, shuffle=True
950
  )