Spaces:
Running
Running
Pedro Cuenca
commited on
Commit
•
bb3f53e
1
Parent(s):
08dd098
Update `resume_from_checkpoint` to use `from_pretrained`.
Browse files- tools/train/train.py +3 -9
tools/train/train.py
CHANGED
@@ -434,22 +434,16 @@ def main():
|
|
434 |
)
|
435 |
|
436 |
if training_args.resume_from_checkpoint is not None:
|
437 |
-
if jax.process_index() == 0:
|
438 |
-
artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
|
439 |
-
else:
|
440 |
-
artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
|
441 |
-
artifact_dir = artifact.download()
|
442 |
-
|
443 |
# load model
|
444 |
model = DalleBart.from_pretrained(
|
445 |
-
|
446 |
)
|
447 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
448 |
print(model.params)
|
449 |
|
450 |
# load tokenizer
|
451 |
tokenizer = AutoTokenizer.from_pretrained(
|
452 |
-
|
453 |
use_fast=True,
|
454 |
)
|
455 |
|
@@ -624,7 +618,7 @@ def main():
|
|
624 |
if training_args.resume_from_checkpoint is not None:
|
625 |
# restore optimizer state and other parameters
|
626 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
627 |
-
state = state.restore_state(
|
628 |
|
629 |
# label smoothed cross entropy
|
630 |
def loss_fn(logits, labels):
|
|
|
434 |
)
|
435 |
|
436 |
if training_args.resume_from_checkpoint is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
# load model
|
438 |
model = DalleBart.from_pretrained(
|
439 |
+
training_args.resume_from_checkpoint, dtype=getattr(jnp, model_args.dtype), abstract_init=True
|
440 |
)
|
441 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
442 |
print(model.params)
|
443 |
|
444 |
# load tokenizer
|
445 |
tokenizer = AutoTokenizer.from_pretrained(
|
446 |
+
model.config.resolved_name_or_path,
|
447 |
use_fast=True,
|
448 |
)
|
449 |
|
|
|
618 |
if training_args.resume_from_checkpoint is not None:
|
619 |
# restore optimizer state and other parameters
|
620 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
621 |
+
state = state.restore_state(model.config.resolved_name_or_path)
|
622 |
|
623 |
# label smoothed cross entropy
|
624 |
def loss_fn(logits, labels):
|