Spaces:
Running
Running
feat(train): simplify tokenizer loading
Browse files- tools/train/train.py +12 -10
tools/train/train.py
CHANGED
@@ -55,7 +55,7 @@ from dalle_mini.model import (
|
|
55 |
)
|
56 |
|
57 |
cc.initialize_cache(
|
58 |
-
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
|
59 |
)
|
60 |
|
61 |
|
@@ -104,6 +104,11 @@ class ModelArguments:
|
|
104 |
state_artifact: str = field(init=False)
|
105 |
|
106 |
def __post_init__(self):
|
|
|
|
|
|
|
|
|
|
|
107 |
if self.restore_state:
|
108 |
assert self.model_name_or_path is not None and (
|
109 |
"/model-" in self.model_name_or_path
|
@@ -511,15 +516,9 @@ def main():
|
|
511 |
)
|
512 |
|
513 |
# Load tokenizer
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
)
|
518 |
-
else:
|
519 |
-
tokenizer = DalleBartTokenizer.from_pretrained(
|
520 |
-
model_args.model_name_or_path,
|
521 |
-
use_fast=True,
|
522 |
-
)
|
523 |
|
524 |
# get PartitionSpec for model params (required to be a dict)
|
525 |
param_spec = set_partitions(model.params)
|
@@ -532,6 +531,9 @@ def main():
|
|
532 |
|
533 |
dataset.preprocess(tokenizer=tokenizer, config=model.config)
|
534 |
|
|
|
|
|
|
|
535 |
# Initialize our training
|
536 |
dropout_rng = jax.random.PRNGKey(training_args.seed_model)
|
537 |
|
|
|
55 |
)
|
56 |
|
57 |
cc.initialize_cache(
|
58 |
+
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
|
59 |
)
|
60 |
|
61 |
|
|
|
104 |
state_artifact: str = field(init=False)
|
105 |
|
106 |
def __post_init__(self):
|
107 |
+
if self.tokenizer_name is None:
|
108 |
+
self.tokenizer_name == self.model_name_or_path
|
109 |
+
assert (
|
110 |
+
self.tokenizer_name is not None
|
111 |
+
), "Tokenizer name or model name/path needs to be specified"
|
112 |
if self.restore_state:
|
113 |
assert self.model_name_or_path is not None and (
|
114 |
"/model-" in self.model_name_or_path
|
|
|
516 |
)
|
517 |
|
518 |
# Load tokenizer
|
519 |
+
tokenizer = DalleBartTokenizer.from_pretrained(
|
520 |
+
model_args.tokenizer_name, use_fast=True
|
521 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
|
523 |
# get PartitionSpec for model params (required to be a dict)
|
524 |
param_spec = set_partitions(model.params)
|
|
|
531 |
|
532 |
dataset.preprocess(tokenizer=tokenizer, config=model.config)
|
533 |
|
534 |
+
# no dropout (hardcoded)
|
535 |
+
model.config.dropout = 0.0
|
536 |
+
|
537 |
# Initialize our training
|
538 |
dropout_rng = jax.random.PRNGKey(training_args.seed_model)
|
539 |
|