Spaces:
Running
Running
feat: add metrics + cleanup
Browse files- dev/seq2seq/run_seq2seq_flax.py +83 -81
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -21,6 +21,7 @@ Script adapted from run_summarization_flax.py
|
|
21 |
import os
|
22 |
import logging
|
23 |
import sys
|
|
|
24 |
from dataclasses import dataclass, field
|
25 |
from pathlib import Path
|
26 |
from typing import Callable, Optional
|
@@ -37,7 +38,6 @@ import optax
|
|
37 |
import transformers
|
38 |
from flax import jax_utils, traverse_util
|
39 |
from flax.serialization import from_bytes, to_bytes
|
40 |
-
import flax.linen as nn
|
41 |
from flax.jax_utils import unreplicate
|
42 |
from flax.training import train_state
|
43 |
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
@@ -136,14 +136,6 @@ class DataTrainingArguments:
|
|
136 |
default=False,
|
137 |
metadata={"help": "Whether to stream the dataset."},
|
138 |
)
|
139 |
-
len_train: Optional[int] = field(
|
140 |
-
default=None,
|
141 |
-
metadata={"help": "Length of training dataset, required for streaming"},
|
142 |
-
)
|
143 |
-
len_eval: Optional[int] = field(
|
144 |
-
default=None,
|
145 |
-
metadata={"help": "Length of validation dataset, required for streaming"},
|
146 |
-
)
|
147 |
max_source_length: Optional[int] = field(
|
148 |
default=128,
|
149 |
metadata={
|
@@ -189,10 +181,6 @@ class DataTrainingArguments:
|
|
189 |
default=False,
|
190 |
metadata={"help": "Log frequency for model"},
|
191 |
)
|
192 |
-
save_model_steps: Optional[int] = field(
|
193 |
-
default=5000,
|
194 |
-
metadata={"help": "For saving/logging the model more frequently"},
|
195 |
-
)
|
196 |
|
197 |
def __post_init__(self):
|
198 |
if self.dataset_repo_or_path is None:
|
@@ -201,6 +189,9 @@ class DataTrainingArguments:
|
|
201 |
|
202 |
class TrainState(train_state.TrainState):
|
203 |
dropout_rng: jnp.ndarray = None
|
|
|
|
|
|
|
204 |
|
205 |
def replicate(self):
|
206 |
return jax_utils.replicate(self).replace(
|
@@ -212,13 +203,17 @@ class TrainState(train_state.TrainState):
|
|
212 |
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
213 |
new_opt_state = from_bytes(self.opt_state, f.read())
|
214 |
|
215 |
-
# restore
|
216 |
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
217 |
training_state = json.load(f)
|
218 |
-
new_step = training_state["step"]
|
219 |
|
220 |
# replace state
|
221 |
-
return self.replace(
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
|
224 |
def data_loader(
|
@@ -259,16 +254,16 @@ def data_loader_streaming(dataset: Dataset, batch_size: int):
|
|
259 |
|
260 |
|
261 |
def create_learning_rate_fn(
|
262 |
-
train_ds_size: int,
|
263 |
-
train_batch_size: int,
|
264 |
-
num_train_epochs: int,
|
265 |
num_warmup_steps: int,
|
266 |
learning_rate: float,
|
267 |
use_decay: bool,
|
|
|
268 |
) -> Callable[[int], jnp.array]:
|
269 |
"""Returns a linear warmup, linear_decay learning rate function."""
|
270 |
-
|
271 |
-
|
|
|
|
|
272 |
warmup_fn = optax.linear_schedule(
|
273 |
init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
|
274 |
)
|
@@ -364,7 +359,6 @@ def main():
|
|
364 |
project="dalle-mini",
|
365 |
job_type="Seq2Seq",
|
366 |
config=parser.parse_args(),
|
367 |
-
save_code=True,
|
368 |
)
|
369 |
|
370 |
if model_args.from_checkpoint is not None:
|
@@ -562,35 +556,26 @@ def main():
|
|
562 |
)
|
563 |
batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
|
564 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
|
|
565 |
if data_args.streaming:
|
566 |
-
|
567 |
-
if
|
568 |
-
data_args.max_train_samples is not None
|
569 |
-
and data_args.max_train_samples < len_train_dataset
|
570 |
-
):
|
571 |
len_train_dataset = data_args.max_train_samples
|
572 |
-
|
573 |
-
len_eval_dataset = data_args.len_eval
|
574 |
-
if (
|
575 |
-
data_args.max_eval_samples is not None
|
576 |
-
and data_args.max_eval_samples < len_eval_dataset
|
577 |
-
):
|
578 |
len_eval_dataset = data_args.max_eval_samples
|
579 |
else:
|
580 |
len_train_dataset = len(train_dataset)
|
581 |
len_eval_dataset = len(eval_dataset)
|
582 |
-
steps_per_epoch =
|
583 |
-
|
584 |
-
|
585 |
|
586 |
# Create learning rate schedule
|
587 |
learning_rate_fn = create_learning_rate_fn(
|
588 |
-
len_train_dataset,
|
589 |
-
train_batch_size,
|
590 |
-
training_args.num_train_epochs,
|
591 |
training_args.warmup_steps,
|
592 |
training_args.learning_rate,
|
593 |
data_args.use_decay,
|
|
|
594 |
)
|
595 |
|
596 |
# We use Optax's "masking" functionality to not apply weight decay
|
@@ -621,7 +606,7 @@ def main():
|
|
621 |
optimizer = optax.adafactor(
|
622 |
learning_rate=learning_rate_fn,
|
623 |
weight_decay_rate=training_args.weight_decay,
|
624 |
-
weight_decay_mask=decay_mask_fn
|
625 |
)
|
626 |
else:
|
627 |
optimizer = optax.adamw(
|
@@ -647,10 +632,9 @@ def main():
|
|
647 |
dropout_rng=dropout_rng,
|
648 |
)
|
649 |
if model_args.from_checkpoint is not None:
|
650 |
-
# restore optimizer state and
|
|
|
651 |
state = state.restore_state(artifact_dir)
|
652 |
-
# TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
|
653 |
-
# TODO: optimizer may use a different step for learning rate, we should serialize/restore entire state
|
654 |
|
655 |
# label smoothed cross entropy
|
656 |
def loss_fn(logits, labels):
|
@@ -659,7 +643,7 @@ def main():
|
|
659 |
return loss
|
660 |
|
661 |
# Define gradient update step fn
|
662 |
-
def train_step(state, batch):
|
663 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
664 |
|
665 |
def compute_loss(params, batch):
|
@@ -673,14 +657,20 @@ def main():
|
|
673 |
grad_fn = jax.value_and_grad(compute_loss)
|
674 |
loss, grads = grad_fn(state.params, batch)
|
675 |
grads = jax.lax.pmean(grads, "batch")
|
676 |
-
state = state.apply_gradients(
|
|
|
|
|
|
|
|
|
|
|
677 |
|
678 |
metrics = {
|
679 |
"loss": loss,
|
680 |
"learning_rate": learning_rate_fn(state.step),
|
681 |
}
|
682 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
683 |
-
|
|
|
684 |
|
685 |
# Define eval fn
|
686 |
def eval_step(params, batch):
|
@@ -697,10 +687,6 @@ def main():
|
|
697 |
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
698 |
p_eval_step = jax.pmap(eval_step, "batch")
|
699 |
|
700 |
-
# Replicate the train state on each device
|
701 |
-
del model._params
|
702 |
-
state = state.replicate()
|
703 |
-
|
704 |
logger.info("***** Running training *****")
|
705 |
logger.info(f" Num examples = {len_train_dataset}")
|
706 |
logger.info(f" Num Epochs = {num_epochs}")
|
@@ -710,13 +696,12 @@ def main():
|
|
710 |
logger.info(
|
711 |
f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
|
712 |
)
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
717 |
|
718 |
# set default x-axis as 'train/step'
|
719 |
-
wandb_log({}, step=
|
720 |
wandb.define_metric("*", step_metric="train/step")
|
721 |
|
722 |
# add interesting config parameters
|
@@ -725,11 +710,12 @@ def main():
|
|
725 |
"len_train": len_train_dataset,
|
726 |
"len_eval": len_eval_dataset,
|
727 |
"batch_size_per_update": batch_size_per_update,
|
728 |
-
"total_steps": total_steps,
|
729 |
-
"total_optimization_steps": total_optimization_steps,
|
730 |
}
|
731 |
)
|
732 |
|
|
|
|
|
|
|
733 |
def run_evaluation():
|
734 |
# ======================== Evaluating ==============================
|
735 |
eval_metrics = []
|
@@ -755,7 +741,7 @@ def main():
|
|
755 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
756 |
|
757 |
# log metrics
|
758 |
-
wandb_log(eval_metrics, step=
|
759 |
|
760 |
# Print metrics and update progress bar
|
761 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
@@ -764,10 +750,9 @@ def main():
|
|
764 |
|
765 |
return eval_metrics
|
766 |
|
767 |
-
def run_save_model(state,
|
768 |
if jax.process_index() == 0:
|
769 |
-
params = jax.device_get(
|
770 |
-
|
771 |
# save model locally
|
772 |
model.save_pretrained(
|
773 |
training_args.output_dir,
|
@@ -778,24 +763,32 @@ def main():
|
|
778 |
tokenizer.save_pretrained(training_args.output_dir)
|
779 |
|
780 |
# save state
|
781 |
-
|
782 |
-
state = unreplicate(state)
|
783 |
with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
|
784 |
-
f.write(to_bytes(
|
785 |
with (Path(training_args.output_dir) / "training_state.json").open(
|
786 |
"w"
|
787 |
) as f:
|
788 |
-
json.dump(
|
|
|
|
|
|
|
|
|
|
|
|
|
789 |
|
790 |
# save to W&B
|
791 |
if data_args.log_model:
|
792 |
# save some space
|
793 |
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
794 |
-
c.cleanup(wandb.util.from_human_size("
|
795 |
|
796 |
-
metadata = {
|
|
|
|
|
|
|
797 |
if eval_metrics is not None:
|
798 |
-
metadata["eval
|
799 |
artifact = wandb.Artifact(
|
800 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
801 |
)
|
@@ -829,18 +822,19 @@ def main():
|
|
829 |
training_args.output_dir,
|
830 |
params=params,
|
831 |
push_to_hub=training_args.push_to_hub,
|
832 |
-
commit_message=f"Saving weights and logs
|
833 |
temp_dir=True, # avoid issues with being in a repository
|
834 |
)
|
835 |
|
|
|
836 |
for epoch in epochs:
|
|
|
837 |
# ======================== Training ================================
|
838 |
-
step
|
839 |
-
wandb_log({"train/epoch": epoch}, step=step)
|
840 |
|
841 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
842 |
if data_args.streaming:
|
843 |
-
train_dataset.set_epoch(epoch)
|
844 |
train_loader = data_loader_streaming(train_dataset, train_batch_size)
|
845 |
else:
|
846 |
rng, input_rng = jax.random.split(rng)
|
@@ -855,23 +849,31 @@ def main():
|
|
855 |
leave=False,
|
856 |
total=steps_per_epoch,
|
857 |
):
|
858 |
-
|
859 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
860 |
|
861 |
if step % data_args.log_interval == 0 and jax.process_index() == 0:
|
862 |
# log metrics
|
863 |
-
wandb_log(
|
864 |
|
|
|
865 |
if training_args.eval_steps and step % training_args.eval_steps == 0:
|
866 |
-
run_evaluation()
|
867 |
|
868 |
-
if step %
|
869 |
-
run_save_model(state,
|
870 |
|
871 |
# log final train metrics
|
872 |
-
|
|
|
873 |
|
874 |
-
train_metric = unreplicate(train_metric)
|
875 |
epochs.write(
|
876 |
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
877 |
)
|
@@ -880,7 +882,7 @@ def main():
|
|
880 |
eval_metrics = run_evaluation()
|
881 |
|
882 |
# save checkpoint after each epoch
|
883 |
-
run_save_model(state,
|
884 |
|
885 |
|
886 |
if __name__ == "__main__":
|
|
|
21 |
import os
|
22 |
import logging
|
23 |
import sys
|
24 |
+
import time
|
25 |
from dataclasses import dataclass, field
|
26 |
from pathlib import Path
|
27 |
from typing import Callable, Optional
|
|
|
38 |
import transformers
|
39 |
from flax import jax_utils, traverse_util
|
40 |
from flax.serialization import from_bytes, to_bytes
|
|
|
41 |
from flax.jax_utils import unreplicate
|
42 |
from flax.training import train_state
|
43 |
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
|
|
136 |
default=False,
|
137 |
metadata={"help": "Whether to stream the dataset."},
|
138 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
max_source_length: Optional[int] = field(
|
140 |
default=128,
|
141 |
metadata={
|
|
|
181 |
default=False,
|
182 |
metadata={"help": "Log frequency for model"},
|
183 |
)
|
|
|
|
|
|
|
|
|
184 |
|
185 |
def __post_init__(self):
|
186 |
if self.dataset_repo_or_path is None:
|
|
|
189 |
|
190 |
class TrainState(train_state.TrainState):
|
191 |
dropout_rng: jnp.ndarray = None
|
192 |
+
epoch: int = 0
|
193 |
+
train_time: float = 0.0 # total time the model trained
|
194 |
+
train_samples: int = 0 # number of samples seen
|
195 |
|
196 |
def replicate(self):
|
197 |
return jax_utils.replicate(self).replace(
|
|
|
203 |
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
204 |
new_opt_state = from_bytes(self.opt_state, f.read())
|
205 |
|
206 |
+
# restore other parameters
|
207 |
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
208 |
training_state = json.load(f)
|
|
|
209 |
|
210 |
# replace state
|
211 |
+
return self.replace(
|
212 |
+
opt_state=new_opt_state,
|
213 |
+
step=training_state["step"],
|
214 |
+
train_time=training_state["train_time"],
|
215 |
+
train_samples=training_state["train_samples"],
|
216 |
+
)
|
217 |
|
218 |
|
219 |
def data_loader(
|
|
|
254 |
|
255 |
|
256 |
def create_learning_rate_fn(
|
|
|
|
|
|
|
257 |
num_warmup_steps: int,
|
258 |
learning_rate: float,
|
259 |
use_decay: bool,
|
260 |
+
num_train_steps: int = None, # used only with `use_decay`, typically train_size // batch_size * num_epochs
|
261 |
) -> Callable[[int], jnp.array]:
|
262 |
"""Returns a linear warmup, linear_decay learning rate function."""
|
263 |
+
if use_decay:
|
264 |
+
assert (
|
265 |
+
num_train_steps is not None
|
266 |
+
), "Learning rate with decay requires number of training steps"
|
267 |
warmup_fn = optax.linear_schedule(
|
268 |
init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
|
269 |
)
|
|
|
359 |
project="dalle-mini",
|
360 |
job_type="Seq2Seq",
|
361 |
config=parser.parse_args(),
|
|
|
362 |
)
|
363 |
|
364 |
if model_args.from_checkpoint is not None:
|
|
|
556 |
)
|
557 |
batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
|
558 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
559 |
+
len_train_dataset, len_eval_dataset = None, None
|
560 |
if data_args.streaming:
|
561 |
+
# we don't know the length, let's just assume max_samples if defined
|
562 |
+
if data_args.max_train_samples is not None:
|
|
|
|
|
|
|
563 |
len_train_dataset = data_args.max_train_samples
|
564 |
+
if data_args.max_eval_samples is not None:
|
|
|
|
|
|
|
|
|
|
|
565 |
len_eval_dataset = data_args.max_eval_samples
|
566 |
else:
|
567 |
len_train_dataset = len(train_dataset)
|
568 |
len_eval_dataset = len(eval_dataset)
|
569 |
+
steps_per_epoch = (
|
570 |
+
len_train_dataset // train_batch_size if len_train_dataset is not None else None
|
571 |
+
)
|
572 |
|
573 |
# Create learning rate schedule
|
574 |
learning_rate_fn = create_learning_rate_fn(
|
|
|
|
|
|
|
575 |
training_args.warmup_steps,
|
576 |
training_args.learning_rate,
|
577 |
data_args.use_decay,
|
578 |
+
steps_per_epoch * num_epochs,
|
579 |
)
|
580 |
|
581 |
# We use Optax's "masking" functionality to not apply weight decay
|
|
|
606 |
optimizer = optax.adafactor(
|
607 |
learning_rate=learning_rate_fn,
|
608 |
weight_decay_rate=training_args.weight_decay,
|
609 |
+
weight_decay_mask=decay_mask_fn,
|
610 |
)
|
611 |
else:
|
612 |
optimizer = optax.adamw(
|
|
|
632 |
dropout_rng=dropout_rng,
|
633 |
)
|
634 |
if model_args.from_checkpoint is not None:
|
635 |
+
# restore optimizer state and other parameters
|
636 |
+
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
637 |
state = state.restore_state(artifact_dir)
|
|
|
|
|
638 |
|
639 |
# label smoothed cross entropy
|
640 |
def loss_fn(logits, labels):
|
|
|
643 |
return loss
|
644 |
|
645 |
# Define gradient update step fn
|
646 |
+
def train_step(state, batch, delta_time):
|
647 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
648 |
|
649 |
def compute_loss(params, batch):
|
|
|
657 |
grad_fn = jax.value_and_grad(compute_loss)
|
658 |
loss, grads = grad_fn(state.params, batch)
|
659 |
grads = jax.lax.pmean(grads, "batch")
|
660 |
+
state = state.apply_gradients(
|
661 |
+
grads=grads,
|
662 |
+
dropout_rng=new_dropout_rng,
|
663 |
+
train_time=state.train_time + delta_time,
|
664 |
+
train_samples=state.train_samples + train_batch_size,
|
665 |
+
)
|
666 |
|
667 |
metrics = {
|
668 |
"loss": loss,
|
669 |
"learning_rate": learning_rate_fn(state.step),
|
670 |
}
|
671 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
672 |
+
|
673 |
+
return state, metrics
|
674 |
|
675 |
# Define eval fn
|
676 |
def eval_step(params, batch):
|
|
|
687 |
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
688 |
p_eval_step = jax.pmap(eval_step, "batch")
|
689 |
|
|
|
|
|
|
|
|
|
690 |
logger.info("***** Running training *****")
|
691 |
logger.info(f" Num examples = {len_train_dataset}")
|
692 |
logger.info(f" Num Epochs = {num_epochs}")
|
|
|
696 |
logger.info(
|
697 |
f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
|
698 |
)
|
699 |
+
epochs = tqdm(
|
700 |
+
range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
|
701 |
+
)
|
|
|
702 |
|
703 |
# set default x-axis as 'train/step'
|
704 |
+
wandb_log({}, step=state.step)
|
705 |
wandb.define_metric("*", step_metric="train/step")
|
706 |
|
707 |
# add interesting config parameters
|
|
|
710 |
"len_train": len_train_dataset,
|
711 |
"len_eval": len_eval_dataset,
|
712 |
"batch_size_per_update": batch_size_per_update,
|
|
|
|
|
713 |
}
|
714 |
)
|
715 |
|
716 |
+
# replicate state on each device
|
717 |
+
state = state.replicate()
|
718 |
+
|
719 |
def run_evaluation():
|
720 |
# ======================== Evaluating ==============================
|
721 |
eval_metrics = []
|
|
|
741 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
742 |
|
743 |
# log metrics
|
744 |
+
wandb_log(eval_metrics, step=get_metrics(state.step), prefix="eval")
|
745 |
|
746 |
# Print metrics and update progress bar
|
747 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
|
|
750 |
|
751 |
return eval_metrics
|
752 |
|
753 |
+
def run_save_model(state, eval_metrics=None):
|
754 |
if jax.process_index() == 0:
|
755 |
+
params = jax.device_get(unreplicate(state.params))
|
|
|
756 |
# save model locally
|
757 |
model.save_pretrained(
|
758 |
training_args.output_dir,
|
|
|
763 |
tokenizer.save_pretrained(training_args.output_dir)
|
764 |
|
765 |
# save state
|
766 |
+
opt_state = unreplicate(state.opt_state)
|
|
|
767 |
with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
|
768 |
+
f.write(to_bytes(opt_state))
|
769 |
with (Path(training_args.output_dir) / "training_state.json").open(
|
770 |
"w"
|
771 |
) as f:
|
772 |
+
json.dump(
|
773 |
+
{
|
774 |
+
k: get_metrics(state[k])
|
775 |
+
for k in ["step", "epoch", "train_time", "train_samples"]
|
776 |
+
},
|
777 |
+
f,
|
778 |
+
)
|
779 |
|
780 |
# save to W&B
|
781 |
if data_args.log_model:
|
782 |
# save some space
|
783 |
c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
|
784 |
+
c.cleanup(wandb.util.from_human_size("10GB"))
|
785 |
|
786 |
+
metadata = {
|
787 |
+
k: get_metrics(state[k])
|
788 |
+
for k in ["step", "epoch", "train_time", "train_samples"]
|
789 |
+
}
|
790 |
if eval_metrics is not None:
|
791 |
+
metadata["eval"] = eval_metrics
|
792 |
artifact = wandb.Artifact(
|
793 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
794 |
)
|
|
|
822 |
training_args.output_dir,
|
823 |
params=params,
|
824 |
push_to_hub=training_args.push_to_hub,
|
825 |
+
commit_message=f"Saving weights and logs at step {get_metrics(state.step)+1}",
|
826 |
temp_dir=True, # avoid issues with being in a repository
|
827 |
)
|
828 |
|
829 |
+
last_time = time.perf_counter()
|
830 |
for epoch in epochs:
|
831 |
+
state.replace(epoch=jax_utils.replicate(epoch))
|
832 |
# ======================== Training ================================
|
833 |
+
wandb_log({"train/epoch": epoch}, step=get_metrics(state.step))
|
|
|
834 |
|
835 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
836 |
if data_args.streaming:
|
837 |
+
train_dataset.set_epoch(epoch) # shuffle dataset
|
838 |
train_loader = data_loader_streaming(train_dataset, train_batch_size)
|
839 |
else:
|
840 |
rng, input_rng = jax.random.split(rng)
|
|
|
849 |
leave=False,
|
850 |
total=steps_per_epoch,
|
851 |
):
|
852 |
+
|
853 |
+
# calculate delta time (we have a lag of one step but it's ok)
|
854 |
+
new_time = time.perf_counter()
|
855 |
+
delta_time = new_time - last_time
|
856 |
+
last_time = new_time
|
857 |
+
|
858 |
+
# train step
|
859 |
+
state, train_metric = p_train_step(state, batch, delta_time)
|
860 |
+
step = get_metrics(state.step)
|
861 |
|
862 |
if step % data_args.log_interval == 0 and jax.process_index() == 0:
|
863 |
# log metrics
|
864 |
+
wandb_log(get_metrics(train_metric), step=step, prefix="train")
|
865 |
|
866 |
+
eval_metrics = None
|
867 |
if training_args.eval_steps and step % training_args.eval_steps == 0:
|
868 |
+
eval_metrics = run_evaluation()
|
869 |
|
870 |
+
if step % training_args.save_steps == 0:
|
871 |
+
run_save_model(state, eval_metrics)
|
872 |
|
873 |
# log final train metrics
|
874 |
+
train_metric = get_metrics(train_metric)
|
875 |
+
wandb_log(train_metric, step=step, prefix="train")
|
876 |
|
|
|
877 |
epochs.write(
|
878 |
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
879 |
)
|
|
|
882 |
eval_metrics = run_evaluation()
|
883 |
|
884 |
# save checkpoint after each epoch
|
885 |
+
run_save_model(state, eval_metrics)
|
886 |
|
887 |
|
888 |
if __name__ == "__main__":
|