feat(train): no batch dimension with pjit
- src/dalle_mini/model/__init__.py +1 -0
- tools/train/train.py +2 -5
src/dalle_mini/model/__init__.py
CHANGED
@@ -1,3 +1,4 @@
 from .configuration import DalleBartConfig
 from .modeling import DalleBart
+from .partitions import set_partitions
 from .tokenizer import DalleBartTokenizer
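The newly exported set_partitions is the piece that maps the model's parameter tree to PartitionSpecs, which pjit needs in order to know how each weight is sharded across the device mesh. A minimal sketch of that pattern, assuming regex-based rules and a model-parallel mesh axis named "mp" (both illustrative assumptions, not dalle-mini's actual rules):

import re

from flax.core.frozen_dict import freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.experimental import PartitionSpec as P

# Illustrative rules: shard attention kernels over a hypothetical "mp"
# (model-parallel) mesh axis; everything else is replicated (spec None).
RULES = [
    (r"(q_proj|k_proj|v_proj)/kernel", P(None, "mp")),
    (r"out_proj/kernel", P("mp", None)),
]

def set_partitions_sketch(params):
    """Return a pytree of PartitionSpecs mirroring the params pytree."""
    flat = flatten_dict(params)
    specs = {
        key: next((spec for pat, spec in RULES if re.search(pat, "/".join(key))), None)
        for key in flat
    }
    return freeze(unflatten_dict(specs))

The resulting spec tree is what a pjit call consumes for the parameters, which is why the training script now needs it exported from the package.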
tools/train/train.py
CHANGED
@@ -38,7 +38,7 @@ from distributed_shampoo import GraftingType, distributed_shampoo
 from flax.core.frozen_dict import freeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
-from flax.training.common_utils import
+from flax.training.common_utils import onehot, stack_forest
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.pjit import pjit
 from tqdm import tqdm
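onehot stays in the import because the loss still needs dense targets; a typical pairing with optax (a common Flax pattern, not necessarily the exact loss code in train.py):

import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot

logits = jnp.zeros((2, 5))   # (batch, vocab) scores from the model
labels = jnp.array([1, 3])   # integer class ids
# onehot(labels, num_classes) -> (batch, vocab) one-hot floats
loss = optax.softmax_cross_entropy(logits, onehot(labels, 5)).mean()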
@@ -764,7 +764,6 @@ def main():
             ),
         )
 
-        grads = jax.lax.pmean(grads, "batch")
         state = state.apply_gradients(
             grads=grads,
             dropout_rng=new_dropout_rng,
@@ -776,7 +775,6 @@ def main():
             "loss": loss,
             "learning_rate": learning_rate_fn(state.step),
         }
-        metrics = jax.lax.pmean(metrics, axis_name="batch")
 
         return state, metrics
 
@@ -788,7 +786,6 @@ def main():
 
         # summarize metrics
         metrics = {"loss": loss}
-        metrics = jax.lax.pmean(metrics, axis_name="batch")
         return metrics
 
     # Create parallel version of the train and eval step
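All three jax.lax.pmean deletions follow from the same fact: pmap gives each replica only its local shard of the batch, so gradients and metrics must be explicitly averaged over the named "batch" axis, while pjit traces the step once over the full global batch and inserts any cross-device communication itself. That is also the commit title's point: the batch no longer carries the extra leading device dimension that pmap required. A self-contained sketch of a pjit step, with a toy linear model and a hypothetical data-parallel mesh axis "dp" (dalle-mini's real state specs come from set_partitions):

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training import train_state
from jax.experimental import PartitionSpec as P, maps
from jax.experimental.pjit import pjit

def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn(params, batch["x"])
        # Under pjit this mean is already over the *global* batch,
        # so no jax.lax.pmean(..., "batch") is needed afterwards.
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), {"loss": loss}

state = train_state.TrainState.create(
    apply_fn=lambda params, x: x @ params["w"],  # toy linear model
    params={"w": jnp.zeros((4, 1))},
    tx=optax.sgd(1e-3),
)

p_train_step = pjit(
    train_step,
    in_axis_resources=(None, P("dp")),  # replicate state, shard batch rows
    out_axis_resources=None,
)

with maps.Mesh(np.asarray(jax.devices()), ("dp",)):
    batch = {"x": jnp.ones((8, 4)), "y": jnp.ones((8, 1))}
    state, metrics = p_train_step(state, batch)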
@@ -861,7 +858,7 @@ def main():
             eval_metrics.append(metrics)
 
         # normalize eval metrics
-        eval_metrics =
+        eval_metrics = stack_forest(eval_metrics)
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
         # log metrics
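stack_forest replaces the pmap-era metrics gathering: it stacks a list of per-step metric pytrees into a single pytree of arrays, after which one jax.tree_map(jnp.mean, ...) averages every metric at once. A small sketch:

import jax
import jax.numpy as jnp
from flax.training.common_utils import stack_forest

eval_metrics = [{"loss": jnp.array(0.9)}, {"loss": jnp.array(0.7)}]
stacked = stack_forest(eval_metrics)        # {"loss": array([0.9, 0.7])}
averaged = jax.tree_map(jnp.mean, stacked)  # {"loss": array(0.8)}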