fix(train): opt_state_shape for distributed_shampoo
tools/train/train.py (CHANGED, +23 -26)
@@ -105,7 +105,6 @@ class ModelArguments:
             self.state_artifact = self.model_name_or_path.replace(
                 "/model-", "/state-", 1
             )
-            raise ValueError("Need a dataset repository or path.")
 
 
 @dataclass
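The retained lines above map a model artifact name to its training-state artifact with a single leftmost replacement. A tiny sketch with a made-up artifact reference (the name below is hypothetical, only the "/model-" pattern matters):

    # Hypothetical artifact reference, for illustration only.
    model_name_or_path = "some-entity/some-project/model-3f0lem84:v12"
    # The count argument (1) rewrites only the first "/model-" occurrence.
    state_artifact = model_name_or_path.replace("/model-", "/state-", 1)
    assert state_artifact == "some-entity/some-project/state-3f0lem84:v12"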
@@ -648,30 +647,29 @@ def main():
 
     # get PartitionSpec for optimizer state
     def get_opt_state_spec_and_shape(param_spec):
+        # get opt_state shape without actual init
+        opt_state_shape = jax.eval_shape(optimizer.init, model.params)
+
+        if training_args.optim == "adam":
+
+            def _opt_state_spec_per_leaf(x):
+                if isinstance(x, FrozenDict):
+                    # variables with same structure as params
+                    return param_spec
+                else:
+                    # other variables such as count
+                    return None
+
+            opt_state_spec = jax.tree_map(
+                _opt_state_spec_per_leaf,
+                opt_state_shape,
+                # return None spec for empty elements
+                is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
+            )
+
+        elif training_args.optim == "adafactor":
+            # factorized state must be replicated (rank different than params)
+            opt_state_spec = None
+
         elif training_args.optim == "distributed_shampoo":
             opt_state_spec = opt_fn.pspec_fn(
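The added lines above are the heart of the commit: jax.eval_shape runs optimizer.init abstractly, so the (potentially huge) optimizer state is described as shapes and dtypes without ever being allocated, and for adam the matching PartitionSpec tree is built by mapping over that abstract state. A standalone sketch of the same trick, using a toy parameter tree, optax.adamw and an assumed "mp" mesh axis as stand-ins for the real model, optimizer and sharding setup:

    import jax
    import jax.numpy as jnp
    import optax
    from flax.core.frozen_dict import FrozenDict
    from jax.experimental import PartitionSpec  # newer JAX: from jax.sharding import PartitionSpec

    # Toy parameter tree standing in for model.params.
    params = FrozenDict({"dense": {"kernel": jnp.zeros((1024, 4096)), "bias": jnp.zeros((4096,))}})
    optimizer = optax.adamw(learning_rate=1e-4)

    # Abstract init: the optimizer state comes back as ShapeDtypeStructs,
    # so no device memory is allocated for it.
    opt_state_shape = jax.eval_shape(optimizer.init, params)

    # Assumed per-parameter specs; in train.py this tree is the param_spec argument.
    param_spec = FrozenDict(
        {"dense": {"kernel": PartitionSpec("mp", None), "bias": PartitionSpec(None)}}
    )

    def _opt_state_spec_per_leaf(x):
        # adam's mu/nu mirror the parameter tree, so they reuse the param specs;
        # scalar bookkeeping such as the step count stays replicated (None).
        return param_spec if isinstance(x, FrozenDict) else None

    opt_state_spec = jax.tree_map(  # jax.tree_util.tree_map in newer JAX
        _opt_state_spec_per_leaf,
        opt_state_shape,
        # treat FrozenDict sub-trees and empty states as single leaves
        is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
    )
    print(opt_state_spec)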
@@ -679,7 +677,6 @@ def main():
                 params_partition_spec=param_spec,
                 partition_spec_for_statistics=PartitionSpec(None, "batch", None),
             )
-            opt_state_shape = opt_fn.shape_and_dtype_fn(model.params)
         else:
             raise NotImplementedError
         return opt_state_spec, opt_state_shape
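For distributed_shampoo the spec still comes from the optimizer's own pspec_fn, while the shape now comes from the same jax.eval_shape call at the top of the function, so every branch returns a consistent (opt_state_spec, opt_state_shape) pair. A hedged sketch, not necessarily the script's exact wiring, of how such a pair is typically consumed, assuming param_spec, optimizer, model and the device mesh already exist as in train.py (newer JAX renames the pjit arguments to in_shardings/out_shardings):

    from jax.experimental.pjit import pjit

    opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(param_spec)

    # opt_state_shape describes the state without materializing it; opt_state_spec
    # tells pjit how to lay the real state out across the mesh at init time.
    init_opt_state = pjit(
        optimizer.init,                     # the real init, executed on device
        in_axis_resources=(param_spec,),    # params arrive already sharded
        out_axis_resources=opt_state_spec,  # state is produced directly in sharded form
    )
    # with mesh:  # the device mesh set up elsewhere in the script
    #     opt_state = init_opt_state(model.params)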
@@ -760,7 +757,7 @@ def main():
     del opt_state
 
     # free memory
-    del model._params
+    del model._params, opt_state_spec, opt_state_shape
 
     # define batch specs
     keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]