Spaces:
Running
Running
Pedro Cuenca
commited on
Commit
•
ae983d7
1
Parent(s):
7e48337
Use DalleBartTokenizer. State restoration reverted to previous method:
Browse filesexplicitly download artifact and use the download directory.
A better solution will be addressed in #120.
- tools/train/train.py +13 -8
tools/train/train.py
CHANGED
@@ -44,7 +44,7 @@ from tqdm import tqdm
|
|
44 |
from transformers import AutoTokenizer, HfArgumentParser
|
45 |
|
46 |
from dalle_mini.data import Dataset
|
47 |
-
from dalle_mini.model import DalleBart, DalleBartConfig
|
48 |
|
49 |
logger = logging.getLogger(__name__)
|
50 |
|
@@ -435,9 +435,15 @@ def main():
|
|
435 |
)
|
436 |
|
437 |
if training_args.resume_from_checkpoint is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
438 |
# load model
|
439 |
model = DalleBart.from_pretrained(
|
440 |
-
|
441 |
dtype=getattr(jnp, model_args.dtype),
|
442 |
abstract_init=True,
|
443 |
)
|
@@ -445,8 +451,8 @@ def main():
|
|
445 |
print(model.params)
|
446 |
|
447 |
# load tokenizer
|
448 |
-
tokenizer =
|
449 |
-
|
450 |
use_fast=True,
|
451 |
)
|
452 |
|
@@ -481,9 +487,8 @@ def main():
|
|
481 |
model_args.tokenizer_name, use_fast=True
|
482 |
)
|
483 |
else:
|
484 |
-
|
485 |
-
|
486 |
-
model.config.resolved_name_or_path,
|
487 |
use_fast=True,
|
488 |
)
|
489 |
|
@@ -621,7 +626,7 @@ def main():
|
|
621 |
if training_args.resume_from_checkpoint is not None:
|
622 |
# restore optimizer state and other parameters
|
623 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
624 |
-
state = state.restore_state(
|
625 |
|
626 |
# label smoothed cross entropy
|
627 |
def loss_fn(logits, labels):
|
|
|
44 |
from transformers import AutoTokenizer, HfArgumentParser
|
45 |
|
46 |
from dalle_mini.data import Dataset
|
47 |
+
from dalle_mini.model import DalleBart, DalleBartConfig, DalleBartTokenizer
|
48 |
|
49 |
logger = logging.getLogger(__name__)
|
50 |
|
|
|
435 |
)
|
436 |
|
437 |
if training_args.resume_from_checkpoint is not None:
|
438 |
+
if jax.process_index() == 0:
|
439 |
+
artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
|
440 |
+
else:
|
441 |
+
artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
|
442 |
+
artifact_dir = artifact.download()
|
443 |
+
|
444 |
# load model
|
445 |
model = DalleBart.from_pretrained(
|
446 |
+
artifact_dir,
|
447 |
dtype=getattr(jnp, model_args.dtype),
|
448 |
abstract_init=True,
|
449 |
)
|
|
|
451 |
print(model.params)
|
452 |
|
453 |
# load tokenizer
|
454 |
+
tokenizer = DalleBartTokenizer.from_pretrained(
|
455 |
+
artifact_dir,
|
456 |
use_fast=True,
|
457 |
)
|
458 |
|
|
|
487 |
model_args.tokenizer_name, use_fast=True
|
488 |
)
|
489 |
else:
|
490 |
+
tokenizer = DalleBartTokenizer.from_pretrained(
|
491 |
+
model_args.model_name_or_path,
|
|
|
492 |
use_fast=True,
|
493 |
)
|
494 |
|
|
|
626 |
if training_args.resume_from_checkpoint is not None:
|
627 |
# restore optimizer state and other parameters
|
628 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
629 |
+
state = state.restore_state(artifact_dir)
|
630 |
|
631 |
# label smoothed cross entropy
|
632 |
def loss_fn(logits, labels):
|