boris commited on
Commit
4cb21dd
1 Parent(s): db5a22a

feat(train): simplify tokenizer loading

Browse files
Files changed (1) hide show
  1. tools/train/train.py +12 -10
tools/train/train.py CHANGED
@@ -55,7 +55,7 @@ from dalle_mini.model import (
55
  )
56
 
57
  cc.initialize_cache(
58
- "/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
59
  )
60
 
61
 
@@ -104,6 +104,11 @@ class ModelArguments:
104
  state_artifact: str = field(init=False)
105
 
106
  def __post_init__(self):
 
 
 
 
 
107
  if self.restore_state:
108
  assert self.model_name_or_path is not None and (
109
  "/model-" in self.model_name_or_path
@@ -511,15 +516,9 @@ def main():
511
  )
512
 
513
  # Load tokenizer
514
- if model_args.tokenizer_name is not None:
515
- tokenizer = DalleBartTokenizer.from_pretrained(
516
- model_args.tokenizer_name, use_fast=True
517
- )
518
- else:
519
- tokenizer = DalleBartTokenizer.from_pretrained(
520
- model_args.model_name_or_path,
521
- use_fast=True,
522
- )
523
 
524
  # get PartitionSpec for model params (required to be a dict)
525
  param_spec = set_partitions(model.params)
@@ -532,6 +531,9 @@ def main():
532
 
533
  dataset.preprocess(tokenizer=tokenizer, config=model.config)
534
 
 
 
 
535
  # Initialize our training
536
  dropout_rng = jax.random.PRNGKey(training_args.seed_model)
537
 
 
55
  )
56
 
57
  cc.initialize_cache(
58
+ "/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
59
  )
60
 
61
 
 
104
  state_artifact: str = field(init=False)
105
 
106
  def __post_init__(self):
107
+ if self.tokenizer_name is None:
108
+ self.tokenizer_name == self.model_name_or_path
109
+ assert (
110
+ self.tokenizer_name is not None
111
+ ), "Tokenizer name or model name/path needs to be specified"
112
  if self.restore_state:
113
  assert self.model_name_or_path is not None and (
114
  "/model-" in self.model_name_or_path
 
516
  )
517
 
518
  # Load tokenizer
519
+ tokenizer = DalleBartTokenizer.from_pretrained(
520
+ model_args.tokenizer_name, use_fast=True
521
+ )
 
 
 
 
 
 
522
 
523
  # get PartitionSpec for model params (required to be a dict)
524
  param_spec = set_partitions(model.params)
 
531
 
532
  dataset.preprocess(tokenizer=tokenizer, config=model.config)
533
 
534
+ # no dropout (hardcoded)
535
+ model.config.dropout = 0.0
536
+
537
  # Initialize our training
538
  dropout_rng = jax.random.PRNGKey(training_args.seed_model)
539