Spaces:
Running
Running
feat(pjit): follow t5x style
Browse files- tools/train/train.py +65 -58
tools/train/train.py
CHANGED
@@ -765,6 +765,7 @@ def main():
|
|
765 |
# define batch specs
|
766 |
keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
|
767 |
batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
|
|
|
768 |
|
769 |
# label smoothed cross entropy
|
770 |
def loss_fn(logits, labels):
|
@@ -774,18 +775,22 @@ def main():
|
|
774 |
|
775 |
# Define gradient update step fn
|
776 |
def train_step(state, batch, delta_time):
|
|
|
777 |
# check correct batch shape during compilation
|
778 |
-
assert batch["labels"].shape[0:
|
779 |
-
training_args.dp_devices,
|
780 |
training_args.gradient_accumulation_steps,
|
781 |
-
|
782 |
), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
|
783 |
-
# create a new rng
|
784 |
-
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
785 |
-
# use a different rng per node
|
786 |
-
dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
|
787 |
|
788 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
789 |
minibatch = unfreeze(minibatch)
|
790 |
labels = minibatch.pop("labels")
|
791 |
logits = state.apply_fn(
|
@@ -795,58 +800,61 @@ def main():
|
|
795 |
|
796 |
grad_fn = jax.value_and_grad(compute_loss)
|
797 |
|
798 |
-
def
|
799 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
800 |
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
)
|
806 |
-
loss, grads = grad_fn(state.params, minibatch)
|
807 |
-
else:
|
808 |
|
809 |
-
|
810 |
-
minibatch = jax.tree_map(
|
811 |
-
lambda x: x[i],
|
812 |
-
device_batch,
|
813 |
-
)
|
814 |
-
return jax.tree_map(
|
815 |
-
lambda x, y: x + y,
|
816 |
-
cumul_loss_grads,
|
817 |
-
grad_fn(state.params, minibatch),
|
818 |
-
)
|
819 |
|
820 |
-
|
821 |
-
|
822 |
-
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
-
|
839 |
-
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
846 |
|
|
|
|
|
847 |
state = state.apply_gradients(
|
848 |
grads=grads,
|
849 |
-
dropout_rng=
|
850 |
train_time=state.train_time + delta_time,
|
851 |
train_samples=state.train_samples + batch_size_per_step,
|
852 |
)
|
@@ -872,7 +880,7 @@ def main():
|
|
872 |
# Create parallel version of the train and eval step
|
873 |
p_train_step = pjit(
|
874 |
train_step,
|
875 |
-
in_axis_resources=(state_spec,
|
876 |
out_axis_resources=(state_spec, None),
|
877 |
donate_argnums=(0,),
|
878 |
)
|
@@ -1053,13 +1061,12 @@ def main():
|
|
1053 |
delta_time = new_time - last_time
|
1054 |
last_time = new_time
|
1055 |
|
1056 |
-
# reshape data into (
|
1057 |
batch = jax.tree_map(
|
1058 |
lambda x: x.reshape(
|
1059 |
(
|
1060 |
-
training_args.dp_devices,
|
1061 |
training_args.gradient_accumulation_steps,
|
1062 |
-
|
1063 |
)
|
1064 |
+ x.shape[1:]
|
1065 |
),
|
|
|
765 |
# define batch specs
|
766 |
keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
|
767 |
batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
|
768 |
+
grad_batch_spec = freeze({k: PartitionSpec(None, "batch") for k in keys})
|
769 |
|
770 |
# label smoothed cross entropy
|
771 |
def loss_fn(logits, labels):
|
|
|
775 |
|
776 |
# Define gradient update step fn
|
777 |
def train_step(state, batch, delta_time):
|
778 |
+
# batch is (gradient_accumulation_steps, minibatch_size, ...)
|
779 |
# check correct batch shape during compilation
|
780 |
+
assert batch["labels"].shape[0:2] == (
|
|
|
781 |
training_args.gradient_accumulation_steps,
|
782 |
+
minibatch_size,
|
783 |
), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
|
|
|
|
|
|
|
|
|
784 |
|
785 |
+
# get a minibatch (one gradient accumulation slice)
|
786 |
+
def get_minibatch(batch, grad_idx):
|
787 |
+
return jax.tree_map(
|
788 |
+
lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
|
789 |
+
batch,
|
790 |
+
)
|
791 |
+
|
792 |
+
def compute_loss(params, minibatch, dropout_rng):
|
793 |
+
# minibatch has dim (batch_size, ...)
|
794 |
minibatch = unfreeze(minibatch)
|
795 |
labels = minibatch.pop("labels")
|
796 |
logits = state.apply_fn(
|
|
|
800 |
|
801 |
grad_fn = jax.value_and_grad(compute_loss)
|
802 |
|
803 |
+
def loss_and_grad(grad_idx, dropout_rng):
|
804 |
+
minibatch = get_minibatch(batch, grad_idx)
|
805 |
+
# ensure batch is sharded over devices
|
806 |
+
minibatch = jax.tree_map(
|
807 |
+
lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
|
808 |
+
)
|
809 |
+
# return loss and grads
|
810 |
+
return grad_fn(state.params, minibatch, dropout_rng)
|
811 |
|
812 |
+
# create a new rng
|
813 |
+
dropout_rng, _ = jax.random.split(state.dropout_rng)
|
814 |
+
# use a different rng per node
|
815 |
+
dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
|
|
|
|
|
|
|
816 |
|
817 |
+
if training_args.gradient_accumulation_steps == 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
818 |
|
819 |
+
def batch_step(dropout_rng):
|
820 |
+
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
821 |
+
loss_grad = loss_and_grad(0, dropout_rng)
|
822 |
+
return loss_grad, new_dropout_rng
|
823 |
+
|
824 |
+
loss_grad, dropout_rng = batch_step(dropout_rng)
|
825 |
+
else:
|
826 |
+
# create initial state for per_minibatch_step loop
|
827 |
+
init_cumul_loss_grad = (
|
828 |
+
0.0,
|
829 |
+
jax.tree_map(jnp.zeros_like, state.params),
|
830 |
+
)
|
831 |
+
init_minibatch_step = (init_cumul_loss_grad, dropout_rng)
|
832 |
+
|
833 |
+
# accumulate gradients
|
834 |
+
def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
|
835 |
+
cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
|
836 |
+
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
837 |
+
loss_grad = loss_and_grad(grad_idx, dropout_rng)
|
838 |
+
cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
|
839 |
+
return cumul_loss_grad, new_dropout_rng
|
840 |
+
|
841 |
+
# loop over gradients
|
842 |
+
loss_grad, dropout_rng = jax.lax.fori_loop(
|
843 |
+
0,
|
844 |
+
training_args.gradient_accumulation_steps,
|
845 |
+
cumul_minibatch_step,
|
846 |
+
init_minibatch_step,
|
847 |
+
)
|
848 |
+
# sum -> mean
|
849 |
+
loss_grad = jax.tree_map(
|
850 |
+
lambda x: x / training_args.gradient_accumulation_steps, loss_grad
|
851 |
+
)
|
852 |
|
853 |
+
# update state
|
854 |
+
loss, grads = loss_grad
|
855 |
state = state.apply_gradients(
|
856 |
grads=grads,
|
857 |
+
dropout_rng=dropout_rng,
|
858 |
train_time=state.train_time + delta_time,
|
859 |
train_samples=state.train_samples + batch_size_per_step,
|
860 |
)
|
|
|
880 |
# Create parallel version of the train and eval step
|
881 |
p_train_step = pjit(
|
882 |
train_step,
|
883 |
+
in_axis_resources=(state_spec, grad_batch_spec, None),
|
884 |
out_axis_resources=(state_spec, None),
|
885 |
donate_argnums=(0,),
|
886 |
)
|
|
|
1061 |
delta_time = new_time - last_time
|
1062 |
last_time = new_time
|
1063 |
|
1064 |
+
# reshape data into (gradient_accumulation_steps, minibatch_size, ...)
|
1065 |
batch = jax.tree_map(
|
1066 |
lambda x: x.reshape(
|
1067 |
(
|
|
|
1068 |
training_args.gradient_accumulation_steps,
|
1069 |
+
minibatch_size,
|
1070 |
)
|
1071 |
+ x.shape[1:]
|
1072 |
),
|