Spaces:
Running
Running
fix: actually replace state
Browse files- dev/seq2seq/run_seq2seq_flax.py +10 -11
dev/seq2seq/run_seq2seq_flax.py
CHANGED
@@ -435,18 +435,16 @@ def main():
|
|
435 |
|
436 |
def restore_state(state, artifact_dir):
|
437 |
# restore optimizer state
|
438 |
-
|
439 |
-
|
440 |
-
opt_state = from_bytes(state.opt_state, f.read())
|
441 |
-
state.replace(opt_state=opt_state)
|
442 |
|
443 |
# restore steps
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
|
451 |
if model_args.from_checkpoint is not None:
|
452 |
artifact = wandb.run.use_artifact(model_args.from_checkpoint)
|
@@ -668,7 +666,8 @@ def main():
|
|
668 |
)
|
669 |
if model_args.from_checkpoint is not None:
|
670 |
# restore optimizer state, step and optimizer_step
|
671 |
-
restore_state(state, artifact_dir)
|
|
|
672 |
|
673 |
# label smoothed cross entropy
|
674 |
def loss_fn(logits, labels):
|
|
|
435 |
|
436 |
def restore_state(state, artifact_dir):
|
437 |
# restore optimizer state
|
438 |
+
with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
|
439 |
+
opt_state = from_bytes(state.opt_state, f.read())
|
|
|
|
|
440 |
|
441 |
# restore steps
|
442 |
+
with (Path(artifact_dir) / 'training_state.json').open('r') as f:
|
443 |
+
training_state = json.load(f)
|
444 |
+
step = training_state['step']
|
445 |
+
optimizer_step = step // training_args.gradient_accumulation_steps
|
446 |
+
|
447 |
+
return step, optimizer_step, opt_state
|
448 |
|
449 |
if model_args.from_checkpoint is not None:
|
450 |
artifact = wandb.run.use_artifact(model_args.from_checkpoint)
|
|
|
666 |
)
|
667 |
if model_args.from_checkpoint is not None:
|
668 |
# restore optimizer state, step and optimizer_step
|
669 |
+
step, optimizer_step, opt_state = restore_state(state, artifact_dir)
|
670 |
+
state = state.replace(step=step, optimizer_step=optimizer_step, opt_state=opt_state)
|
671 |
|
672 |
# label smoothed cross entropy
|
673 |
def loss_fn(logits, labels):
|