feat(train): another 25% faster
tools/train/train.py  CHANGED  (+21 -21)
@@ -36,10 +36,10 @@ import transformers
 import wandb
 from datasets import Dataset
 from distributed_shampoo import GraftingType, distributed_shampoo
-from flax.core.frozen_dict import FrozenDict, freeze
+from flax.core.frozen_dict import FrozenDict, freeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
-from flax.training.common_utils import onehot
+from flax.training.common_utils import onehot
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.pjit import pjit, with_sharding_constraint
 from tqdm import tqdm
@@ -382,7 +382,7 @@ class TrainState(train_state.TrainState):

 class MetricsLogger:
     def __init__(self, state):
-        self.step = state.step
+        self.step = int(state.step)
         self.time = time.perf_counter()

     def get_all_train_metrics(self, train_metrics, state):
@@ -792,8 +792,7 @@ def main():

     def compute_loss(params, minibatch, dropout_rng):
         # minibatch has dim (batch_size, ...)
-        minibatch =
-        labels = minibatch.pop("labels")
+        minibatch, labels = minibatch.pop("labels")
         logits = state.apply_fn(
             **minibatch, params=params, dropout_rng=dropout_rng, train=True
         )[0]
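The rewritten compute_loss relies on flax's FrozenDict.pop, which returns a new dict without the key together with the popped value, so the labels can be split off the frozen minibatch without first converting it to a mutable dict. A minimal sketch of that behaviour, with placeholder field names rather than the script's real schema:

    import jax.numpy as jnp
    from flax.core.frozen_dict import freeze

    batch = freeze({"input_ids": jnp.ones((2, 4), dtype=jnp.int32),
                    "labels": jnp.zeros((2, 4), dtype=jnp.int32)})
    # FrozenDict.pop does not mutate; it returns (dict_without_key, value)
    rest, labels = batch.pop("labels")
    assert "labels" not in rest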
@@ -883,14 +882,10 @@ def main():

     # Define eval fn
     def eval_step(params, batch):
-        batch =
-        labels = batch.pop("labels")
+        batch, labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
         loss = loss_fn(logits, labels)
-
-        # summarize metrics
-        metrics = {"loss": loss}
-        return metrics
+        return loss

     # Create parallel version of the train and eval step
     p_train_step = pjit(
@@ -940,7 +935,6 @@ def main():

     def run_evaluation():
         # ======================== Evaluating ==============================
-        eval_metrics = []
         if training_args.do_eval:
             eval_loader = dataset.dataloader("eval", eval_batch_size)
             eval_steps = (
@@ -948,6 +942,7 @@ def main():
                 if len_eval_dataset is not None
                 else None
             )
+            eval_loss = []
             for batch in tqdm(
                 eval_loader,
                 desc="Evaluating...",
@@ -955,13 +950,15 @@ def main():
                 leave=False,
                 total=eval_steps,
             ):
-                #
-
-
+                # freeze batch to pass safely to JAX transforms
+                batch = freeze(batch)
+                # accumulate losses async
+                eval_loss.append(p_eval_step(state.params, batch))

-                #
-
-
+            # get the mean of the loss
+            eval_loss = jnp.stack(eval_loss)
+            eval_loss = jnp.mean(eval_loss)
+            eval_metrics = {"loss": eval_loss}

             # log metrics
             metrics_logger.log(eval_metrics, step=state.step, prefix="eval")
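Collecting the raw device values returned by p_eval_step and reducing them only after the loop plays to JAX's asynchronous dispatch: no per-batch host transfer happens, and the single jnp.stack / jnp.mean at the end is the only point that waits on the device. A rough sketch of the pattern, reusing the script's own names (p_eval_step, eval_loader, state) purely for illustration:

    losses = []
    for batch in eval_loader:
        batch = freeze(batch)
        # returns immediately with a device array; nothing is copied to the host yet
        losses.append(p_eval_step(state.params, batch))
    # one synchronization when the mean is actually materialized (e.g. for logging)
    eval_loss = jnp.mean(jnp.stack(losses))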
@@ -1050,6 +1047,7 @@ def main():
     # init variables
     last_time = time.perf_counter()
     train_metrics = None
+    step = int(state.step)

     with maps.mesh(mesh.devices, mesh.axis_names):
         for epoch in epochs:
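Seeding step once with int(state.step) and then counting on the host (step += 1 in the train-step hunk below) means the logging, eval and save conditions never read state.step back from the device inside the loop; int() on a JAX array blocks until the value is ready, so doing it once up front avoids a per-step sync. MetricsLogger.__init__ gets the same int(state.step) conversion above. A hedged sketch, with loader standing in for the real dataloader:

    step = int(state.step)  # single device-to-host read, before the loop
    for batch in loader:
        state, train_metrics = p_train_step(state, batch, delta_time)
        step += 1           # plain Python counter; no device sync per iteration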
@@ -1088,10 +1086,12 @@ def main():
                 ),
                 batch,
             )
+            # freeze batch to pass safely to jax transforms
+            batch = freeze(batch)

             # train step
-            state, train_metrics = p_train_step(state,
-            step
+            state, train_metrics = p_train_step(state, batch, delta_time)
+            step += 1

             if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                 all_metrics = metrics_logger.get_all_train_metrics(
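The added freeze(batch) calls (here and in the eval loop) turn the plain dict coming out of the dataloader into flax's immutable FrozenDict before it is handed to the pjit-compiled step, which is what the "pass safely to jax transforms" comments refer to. A small, self-contained illustration of what freeze does, with made-up values:

    import jax.numpy as jnp
    from flax.core.frozen_dict import FrozenDict, freeze

    batch = {"input_ids": jnp.ones((2, 4), dtype=jnp.int32)}
    frozen = freeze(batch)  # immutable copy that is still a valid JAX pytree
    assert isinstance(frozen, FrozenDict)
    assert frozen["input_ids"].shape == (2, 4)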
@@ -1100,7 +1100,7 @@ def main():
                 metrics_logger.log(all_metrics, step=step, prefix="train")

             eval_metrics = None
-            if
+            if step % training_args.eval_steps == 0:
                 eval_metrics = run_evaluation()

             if step % training_args.save_steps == 0: