Spaces:
Running
Running
fix: comment
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -753,6 +753,7 @@ def main():
|
|
753 |
# restore optimizer state and step
|
754 |
state = state.restore_state(artifact_dir)
|
755 |
# TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
|
|
|
756 |
|
757 |
# label smoothed cross entropy
|
758 |
def loss_fn(logits, labels):
|
@@ -937,7 +938,7 @@ def main():
|
|
937 |
for epoch in epochs:
|
938 |
# ======================== Training ================================
|
939 |
step = unreplicate(state.step)
|
940 |
-
|
941 |
|
942 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
943 |
if data_args.streaming:
|
|
|
753 |
# restore optimizer state and step
|
754 |
state = state.restore_state(artifact_dir)
|
755 |
# TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
|
756 |
+
# TODO: the optimizer may use a different step for the learning rate; we should serialize/restore the entire state
|
757 |
|
758 |
# label smoothed cross entropy
|
759 |
def loss_fn(logits, labels):
|
|
|
938 |
for epoch in epochs:
|
939 |
# ======================== Training ================================
|
940 |
step = unreplicate(state.step)
|
941 |
+
wandb_log({"train/epoch": epoch}, step=step)
|
942 |
|
943 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
944 |
if data_args.streaming:
|