Spaces:
Running
Running
feat(train): improve pjit speed
Browse files- src/dalle_mini/data.py +7 -28
- tools/train/train.py +78 -42
src/dalle_mini/data.py
CHANGED
@@ -152,14 +152,7 @@ class Dataset:
|
|
152 |
),
|
153 |
)
|
154 |
|
155 |
-
def dataloader(
|
156 |
-
self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
|
157 |
-
):
|
158 |
-
num_devices = jax.local_device_count()
|
159 |
-
total_batch_size = per_device_batch_size * num_devices
|
160 |
-
if gradient_accumulation_steps is not None:
|
161 |
-
total_batch_size *= gradient_accumulation_steps
|
162 |
-
|
163 |
def _dataloader_datasets_non_streaming(
|
164 |
dataset: Dataset,
|
165 |
rng: jax.random.PRNGKey = None,
|
@@ -168,7 +161,7 @@ class Dataset:
|
|
168 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
169 |
Shuffle batches if rng is set.
|
170 |
"""
|
171 |
-
steps_per_epoch = len(dataset) // total_batch_size
|
172 |
|
173 |
if rng is not None:
|
174 |
batch_idx = jax.random.permutation(rng, len(dataset))
|
@@ -176,20 +169,13 @@ class Dataset:
|
|
176 |
batch_idx = jnp.arange(len(dataset))
|
177 |
|
178 |
batch_idx = batch_idx[
|
179 |
-
: steps_per_epoch * total_batch_size
|
180 |
] # Skip incomplete batch.
|
181 |
-
batch_idx = batch_idx.reshape((steps_per_epoch, total_batch_size))
|
182 |
|
183 |
for idx in batch_idx:
|
184 |
batch = dataset[idx]
|
185 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
186 |
-
if gradient_accumulation_steps is not None:
|
187 |
-
batch = jax.tree_map(
|
188 |
-
lambda x: x.reshape(
|
189 |
-
(gradient_accumulation_steps, -1) + x.shape[1:]
|
190 |
-
),
|
191 |
-
batch,
|
192 |
-
)
|
193 |
yield batch
|
194 |
|
195 |
def _dataloader_datasets_streaming(
|
@@ -205,22 +191,15 @@ class Dataset:
|
|
205 |
# For validation data we put the entire set on each host as we could lose
|
206 |
# too many samples on pods
|
207 |
if epoch is not None:
|
208 |
-
|
|
|
209 |
dataset.set_epoch(epoch)
|
210 |
epoch += 1
|
211 |
for item in dataset:
|
212 |
for k, v in item.items():
|
213 |
batch[k].append(v)
|
214 |
-
if len(batch[keys[0]]) == total_batch_size:
|
215 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
216 |
-
if gradient_accumulation_steps is not None:
|
217 |
-
# training mode
|
218 |
-
batch = jax.tree_map(
|
219 |
-
lambda x: x.reshape(
|
220 |
-
(gradient_accumulation_steps, -1) + x.shape[1:]
|
221 |
-
),
|
222 |
-
batch,
|
223 |
-
)
|
224 |
yield batch
|
225 |
batch = {k: [] for k in keys}
|
226 |
first_loop = False
|
|
|
152 |
),
|
153 |
)
|
154 |
|
155 |
+
def dataloader(self, split, batch_size, epoch=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
def _dataloader_datasets_non_streaming(
|
157 |
dataset: Dataset,
|
158 |
rng: jax.random.PRNGKey = None,
|
|
|
161 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
162 |
Shuffle batches if rng is set.
|
163 |
"""
|
164 |
+
steps_per_epoch = len(dataset) // batch_size
|
165 |
|
166 |
if rng is not None:
|
167 |
batch_idx = jax.random.permutation(rng, len(dataset))
|
|
|
169 |
batch_idx = jnp.arange(len(dataset))
|
170 |
|
171 |
batch_idx = batch_idx[
|
172 |
+
: steps_per_epoch * batch_size
|
173 |
] # Skip incomplete batch.
|
174 |
+
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
175 |
|
176 |
for idx in batch_idx:
|
177 |
batch = dataset[idx]
|
178 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
yield batch
|
180 |
|
181 |
def _dataloader_datasets_streaming(
|
|
|
191 |
# For validation data we put the entire set on each host as we could lose
|
192 |
# too many samples on pods
|
193 |
if epoch is not None:
|
194 |
+
assert split == "train"
|
195 |
+
# reshuffle training data at each epoch
|
196 |
dataset.set_epoch(epoch)
|
197 |
epoch += 1
|
198 |
for item in dataset:
|
199 |
for k, v in item.items():
|
200 |
batch[k].append(v)
|
201 |
+
if len(batch[keys[0]]) == batch_size:
|
202 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
yield batch
|
204 |
batch = {k: [] for k in keys}
|
205 |
first_loop = False
|
tools/train/train.py
CHANGED
@@ -36,12 +36,12 @@ import transformers
|
|
36 |
import wandb
|
37 |
from datasets import Dataset
|
38 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
39 |
-
from flax.core.frozen_dict import FrozenDict, freeze
|
40 |
from flax.serialization import from_bytes, to_bytes
|
41 |
from flax.training import train_state
|
42 |
from flax.training.common_utils import onehot, stack_forest
|
43 |
from jax.experimental import PartitionSpec, maps
|
44 |
-
from jax.experimental.pjit import pjit
|
45 |
from tqdm import tqdm
|
46 |
from transformers import HfArgumentParser
|
47 |
|
@@ -551,12 +551,12 @@ def main():
|
|
551 |
num_epochs = training_args.num_train_epochs
|
552 |
# batch size
|
553 |
minibatch_size = (
|
554 |
-
training_args.per_device_train_batch_size *
|
555 |
)
|
556 |
batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
|
557 |
batch_size_per_step = batch_size_per_node * jax.process_count()
|
558 |
eval_batch_size = (
|
559 |
-
training_args.per_device_eval_batch_size *
|
560 |
)
|
561 |
len_train_dataset, len_eval_dataset = dataset.length
|
562 |
steps_per_epoch = (
|
@@ -762,6 +762,10 @@ def main():
|
|
762 |
# free memory
|
763 |
del model._params
|
764 |
|
|
|
|
|
|
|
|
|
765 |
# label smoothed cross entropy
|
766 |
def loss_fn(logits, labels):
|
767 |
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
|
@@ -771,16 +775,18 @@ def main():
|
|
771 |
# Define gradient update step fn
|
772 |
def train_step(state, batch, delta_time):
|
773 |
# check correct batch shape during compilation
|
774 |
-
assert batch["labels"].shape[0:
|
|
|
775 |
training_args.gradient_accumulation_steps,
|
776 |
-
|
777 |
-
), f"Expected label batch of shape
|
778 |
# create a new rng
|
779 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
780 |
# use a different rng per node
|
781 |
dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
|
782 |
|
783 |
def compute_loss(params, minibatch):
|
|
|
784 |
labels = minibatch.pop("labels")
|
785 |
logits = state.apply_fn(
|
786 |
**minibatch, params=params, dropout_rng=dropout_rng, train=True
|
@@ -789,32 +795,52 @@ def main():
|
|
789 |
|
790 |
grad_fn = jax.value_and_grad(compute_loss)
|
791 |
|
792 |
-
|
793 |
-
|
794 |
-
loss, grads = grad_fn(state.params, minibatch)
|
795 |
-
else:
|
796 |
|
797 |
-
|
798 |
-
minibatch = jax.tree_map(
|
799 |
-
|
800 |
-
|
801 |
-
cumul_loss_grads,
|
802 |
-
grad_fn(state.params, minibatch),
|
803 |
)
|
|
|
|
|
804 |
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
818 |
|
819 |
state = state.apply_gradients(
|
820 |
grads=grads,
|
@@ -832,6 +858,7 @@ def main():
|
|
832 |
|
833 |
# Define eval fn
|
834 |
def eval_step(params, batch):
|
|
|
835 |
labels = batch.pop("labels")
|
836 |
logits = model(**batch, params=params, train=False)[0]
|
837 |
loss = loss_fn(logits, labels)
|
@@ -843,13 +870,13 @@ def main():
|
|
843 |
# Create parallel version of the train and eval step
|
844 |
p_train_step = pjit(
|
845 |
train_step,
|
846 |
-
in_axis_resources=(state_spec,
|
847 |
out_axis_resources=(state_spec, None),
|
848 |
donate_argnums=(0,),
|
849 |
)
|
850 |
p_eval_step = pjit(
|
851 |
eval_step,
|
852 |
-
in_axis_resources=(param_spec,
|
853 |
out_axis_resources=None,
|
854 |
)
|
855 |
|
@@ -890,9 +917,7 @@ def main():
|
|
890 |
# ======================== Evaluating ==============================
|
891 |
eval_metrics = []
|
892 |
if training_args.do_eval:
|
893 |
-
eval_loader = dataset.dataloader(
|
894 |
-
"eval", training_args.per_device_eval_batch_size
|
895 |
-
)
|
896 |
eval_steps = (
|
897 |
len_eval_dataset // eval_batch_size
|
898 |
if len_eval_dataset is not None
|
@@ -905,8 +930,8 @@ def main():
|
|
905 |
leave=False,
|
906 |
total=eval_steps,
|
907 |
):
|
908 |
-
#
|
909 |
-
metrics = p_eval_step(state.params, batch)
|
910 |
eval_metrics.append(metrics)
|
911 |
|
912 |
# normalize eval metrics
|
@@ -1010,8 +1035,7 @@ def main():
|
|
1010 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
1011 |
train_loader = dataset.dataloader(
|
1012 |
"train",
|
1013 |
-
|
1014 |
-
training_args.gradient_accumulation_steps,
|
1015 |
epoch,
|
1016 |
)
|
1017 |
# train
|
@@ -1022,15 +1046,27 @@ def main():
|
|
1022 |
leave=False,
|
1023 |
total=steps_per_epoch,
|
1024 |
):
|
1025 |
-
|
1026 |
# calculate delta time (we have a lag of one step but it's ok)
|
1027 |
new_time = time.perf_counter()
|
1028 |
delta_time = new_time - last_time
|
1029 |
last_time = new_time
|
1030 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1031 |
# train step
|
1032 |
-
state, train_metrics = p_train_step(state, batch, delta_time)
|
1033 |
-
step = state.step
|
1034 |
|
1035 |
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
1036 |
all_metrics = metrics_logger.get_all_train_metrics(
|
|
|
36 |
import wandb
|
37 |
from datasets import Dataset
|
38 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
39 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
40 |
from flax.serialization import from_bytes, to_bytes
|
41 |
from flax.training import train_state
|
42 |
from flax.training.common_utils import onehot, stack_forest
|
43 |
from jax.experimental import PartitionSpec, maps
|
44 |
+
from jax.experimental.pjit import pjit, with_sharding_constraint
|
45 |
from tqdm import tqdm
|
46 |
from transformers import HfArgumentParser
|
47 |
|
|
|
551 |
num_epochs = training_args.num_train_epochs
|
552 |
# batch size
|
553 |
minibatch_size = (
|
554 |
+
training_args.per_device_train_batch_size * training_args.dp_devices
|
555 |
)
|
556 |
batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
|
557 |
batch_size_per_step = batch_size_per_node * jax.process_count()
|
558 |
eval_batch_size = (
|
559 |
+
training_args.per_device_eval_batch_size * training_args.dp_devices
|
560 |
)
|
561 |
len_train_dataset, len_eval_dataset = dataset.length
|
562 |
steps_per_epoch = (
|
|
|
762 |
# free memory
|
763 |
del model._params
|
764 |
|
765 |
+
# define batch specs
|
766 |
+
keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
|
767 |
+
batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
|
768 |
+
|
769 |
# label smoothed cross entropy
|
770 |
def loss_fn(logits, labels):
|
771 |
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
|
|
|
775 |
# Define gradient update step fn
|
776 |
def train_step(state, batch, delta_time):
|
777 |
# check correct batch shape during compilation
|
778 |
+
assert batch["labels"].shape[0:3] == (
|
779 |
+
training_args.dp_devices,
|
780 |
training_args.gradient_accumulation_steps,
|
781 |
+
training_args.per_device_train_batch_size,
|
782 |
+
), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
|
783 |
# create a new rng
|
784 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
785 |
# use a different rng per node
|
786 |
dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
|
787 |
|
788 |
def compute_loss(params, minibatch):
|
789 |
+
minibatch = unfreeze(minibatch)
|
790 |
labels = minibatch.pop("labels")
|
791 |
logits = state.apply_fn(
|
792 |
**minibatch, params=params, dropout_rng=dropout_rng, train=True
|
|
|
795 |
|
796 |
grad_fn = jax.value_and_grad(compute_loss)
|
797 |
|
798 |
+
def loss_grad_per_device(device_batch):
|
799 |
+
# device_batch has format (gradient_accumulation_steps, batch_size, ...)
|
|
|
|
|
800 |
|
801 |
+
if training_args.gradient_accumulation_steps == 1:
|
802 |
+
minibatch = jax.tree_map(
|
803 |
+
lambda x: x[0],
|
804 |
+
device_batch,
|
|
|
|
|
805 |
)
|
806 |
+
loss, grads = grad_fn(state.params, minibatch)
|
807 |
+
else:
|
808 |
|
809 |
+
def _cumul_loss_grads(i, cumul_loss_grads):
|
810 |
+
minibatch = jax.tree_map(
|
811 |
+
lambda x: x[i],
|
812 |
+
device_batch,
|
813 |
+
)
|
814 |
+
return jax.tree_map(
|
815 |
+
lambda x, y: x + y,
|
816 |
+
cumul_loss_grads,
|
817 |
+
grad_fn(state.params, minibatch),
|
818 |
+
)
|
819 |
+
|
820 |
+
init_loss_grads = (
|
821 |
+
0.0,
|
822 |
+
jax.tree_map(jnp.zeros_like, state.params),
|
823 |
+
)
|
824 |
+
loss, grads = jax.tree_map(
|
825 |
+
lambda x: x / training_args.gradient_accumulation_steps,
|
826 |
+
jax.lax.fori_loop(
|
827 |
+
0,
|
828 |
+
training_args.gradient_accumulation_steps,
|
829 |
+
_cumul_loss_grads,
|
830 |
+
init_loss_grads,
|
831 |
+
),
|
832 |
+
)
|
833 |
+
return loss, grads
|
834 |
+
|
835 |
+
# calculate loss, grads per dp device
|
836 |
+
# batch has shape (dp_devices, gradient_accumulation_steps, batch_per_dp_device, ...)
|
837 |
+
loss, grads = jax.vmap(loss_grad_per_device, in_axes=0, out_axes=(0, 0))(batch)
|
838 |
+
# enforce sharding constraints to avoid OOM
|
839 |
+
loss = with_sharding_constraint(loss, PartitionSpec("batch"))
|
840 |
+
grads = with_sharding_constraint(grads, PartitionSpec("batch"))
|
841 |
+
# calculate the mean over all devices
|
842 |
+
loss = jnp.mean(loss)
|
843 |
+
grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), grads)
|
844 |
|
845 |
state = state.apply_gradients(
|
846 |
grads=grads,
|
|
|
858 |
|
859 |
# Define eval fn
|
860 |
def eval_step(params, batch):
|
861 |
+
batch = unfreeze(batch)
|
862 |
labels = batch.pop("labels")
|
863 |
logits = model(**batch, params=params, train=False)[0]
|
864 |
loss = loss_fn(logits, labels)
|
|
|
870 |
# Create parallel version of the train and eval step
|
871 |
p_train_step = pjit(
|
872 |
train_step,
|
873 |
+
in_axis_resources=(state_spec, batch_spec, None),
|
874 |
out_axis_resources=(state_spec, None),
|
875 |
donate_argnums=(0,),
|
876 |
)
|
877 |
p_eval_step = pjit(
|
878 |
eval_step,
|
879 |
+
in_axis_resources=(param_spec, batch_spec),
|
880 |
out_axis_resources=None,
|
881 |
)
|
882 |
|
|
|
917 |
# ======================== Evaluating ==============================
|
918 |
eval_metrics = []
|
919 |
if training_args.do_eval:
|
920 |
+
eval_loader = dataset.dataloader("eval", eval_batch_size)
|
|
|
|
|
921 |
eval_steps = (
|
922 |
len_eval_dataset // eval_batch_size
|
923 |
if len_eval_dataset is not None
|
|
|
930 |
leave=False,
|
931 |
total=eval_steps,
|
932 |
):
|
933 |
+
# TODO: make this more efficient once training loop is fast
|
934 |
+
metrics = p_eval_step(state.params, freeze(batch))
|
935 |
eval_metrics.append(metrics)
|
936 |
|
937 |
# normalize eval metrics
|
|
|
1035 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
1036 |
train_loader = dataset.dataloader(
|
1037 |
"train",
|
1038 |
+
batch_size_per_node,
|
|
|
1039 |
epoch,
|
1040 |
)
|
1041 |
# train
|
|
|
1046 |
leave=False,
|
1047 |
total=steps_per_epoch,
|
1048 |
):
|
|
|
1049 |
# calculate delta time (we have a lag of one step but it's ok)
|
1050 |
new_time = time.perf_counter()
|
1051 |
delta_time = new_time - last_time
|
1052 |
last_time = new_time
|
1053 |
|
1054 |
+
# reshape data into (dp_devices, gradient_accumulation_steps, batch_per_dp_device, ...)
|
1055 |
+
batch = jax.tree_map(
|
1056 |
+
lambda x: x.reshape(
|
1057 |
+
(
|
1058 |
+
training_args.dp_devices,
|
1059 |
+
training_args.gradient_accumulation_steps,
|
1060 |
+
-1,
|
1061 |
+
)
|
1062 |
+
+ x.shape[1:]
|
1063 |
+
),
|
1064 |
+
batch,
|
1065 |
+
)
|
1066 |
+
|
1067 |
# train step
|
1068 |
+
state, train_metrics = p_train_step(state, freeze(batch), delta_time)
|
1069 |
+
step = int(state.step)
|
1070 |
|
1071 |
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
1072 |
all_metrics = metrics_logger.get_all_train_metrics(
|