boris commited on
Commit
225b6ff
1 Parent(s): fa5b058

fix(train): opt_state_shape for distributed_shampoo

Browse files
Files changed (1) hide show
  1. tools/train/train.py +23 -26
tools/train/train.py CHANGED
@@ -105,7 +105,6 @@ class ModelArguments:
105
  self.state_artifact = self.model_name_or_path.replace(
106
  "/model-", "/state-", 1
107
  )
108
- raise ValueError("Need a dataset repository or path.")
109
 
110
 
111
  @dataclass
@@ -648,30 +647,29 @@ def main():
648
 
649
  # get PartitionSpec for optimizer state
650
  def get_opt_state_spec_and_shape(param_spec):
651
- if training_args.optim in ["adam", "adafactor"]:
652
- # get opt_state shape without actual init
653
- opt_state_shape = jax.eval_shape(optimizer.init, model.params)
654
-
655
- if training_args.optim == "adam":
656
-
657
- def _opt_state_spec_per_leaf(x):
658
- if isinstance(x, FrozenDict):
659
- # variables with same structure as params
660
- return param_spec
661
- else:
662
- # other variables such as count
663
- return None
664
-
665
- opt_state_spec = jax.tree_map(
666
- _opt_state_spec_per_leaf,
667
- opt_state_shape,
668
- # return None spec for empty elements
669
- is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
670
- )
671
 
672
- elif training_args.optim == "adafactor":
673
- # factorized state must be replicated (rank different than params)
674
- opt_state_spec = None
675
 
676
  elif training_args.optim == "distributed_shampoo":
677
  opt_state_spec = opt_fn.pspec_fn(
@@ -679,7 +677,6 @@ def main():
679
  params_partition_spec=param_spec,
680
  partition_spec_for_statistics=PartitionSpec(None, "batch", None),
681
  )
682
- opt_state_shape = opt_fn.shape_and_dtype_fn(model.params)
683
  else:
684
  raise NotImplementedError
685
  return opt_state_spec, opt_state_shape
@@ -760,7 +757,7 @@ def main():
760
  del opt_state
761
 
762
  # free memory
763
- del model._params
764
 
765
  # define batch specs
766
  keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
 
105
  self.state_artifact = self.model_name_or_path.replace(
106
  "/model-", "/state-", 1
107
  )
 
108
 
109
 
110
  @dataclass
 
647
 
648
  # get PartitionSpec for optimizer state
649
  def get_opt_state_spec_and_shape(param_spec):
650
+ # get opt_state shape without actual init
651
+ opt_state_shape = jax.eval_shape(optimizer.init, model.params)
652
+
653
+ if training_args.optim == "adam":
654
+
655
+ def _opt_state_spec_per_leaf(x):
656
+ if isinstance(x, FrozenDict):
657
+ # variables with same structure as params
658
+ return param_spec
659
+ else:
660
+ # other variables such as count
661
+ return None
662
+
663
+ opt_state_spec = jax.tree_map(
664
+ _opt_state_spec_per_leaf,
665
+ opt_state_shape,
666
+ # return None spec for empty elements
667
+ is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
668
+ )
 
669
 
670
+ elif training_args.optim == "adafactor":
671
+ # factorized state must be replicated (rank different than params)
672
+ opt_state_spec = None
673
 
674
  elif training_args.optim == "distributed_shampoo":
675
  opt_state_spec = opt_fn.pspec_fn(
 
677
  params_partition_spec=param_spec,
678
  partition_spec_for_statistics=PartitionSpec(None, "batch", None),
679
  )
 
680
  else:
681
  raise NotImplementedError
682
  return opt_state_spec, opt_state_shape
 
757
  del opt_state
758
 
759
  # free memory
760
+ del model._params, opt_state_spec, opt_state_shape
761
 
762
  # define batch specs
763
  keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]