fix(train): update model name
tools/train/train.py  +3 -4
@@ -41,10 +41,9 @@ from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard_prng_key
 from tqdm import tqdm
 from transformers import AutoTokenizer, HfArgumentParser
-from transformers.models.bart.modeling_flax_bart import BartConfig
 
 from dalle_mini.data import Dataset
-from dalle_mini.model import DalleBartConfig,
+from dalle_mini.model import DalleBartConfig, DalleBart
 
 logger = logging.getLogger(__name__)
 
@@ -418,7 +417,7 @@ def main():
 
     # Load or create new model
     if model_args.model_name_or_path:
-        model =
+        model = DalleBart.from_pretrained(
             model_args.model_name_or_path,
             config=config,
             seed=training_args.seed_model,
@@ -427,7 +426,7 @@ def main():
         # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
         print(model.params)
     else:
-        model =
+        model = DalleBart(
             config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
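For reference, a minimal standalone sketch of the two code paths this commit switches to, outside of main(). The checkpoint path, the default-constructed config, and the literal seed/dtype values are placeholders; the DalleBart/DalleBartConfig names and the config/seed/dtype arguments are taken from the diff above.

import jax.numpy as jnp

from dalle_mini.model import DalleBart, DalleBartConfig

config = DalleBartConfig()  # assumption: default config values are enough for a toy example
model_name_or_path = "path/to/checkpoint"  # placeholder; set to None to initialize from scratch

if model_name_or_path:
    # Resume from an existing checkpoint (mirrors the from_pretrained branch).
    model = DalleBart.from_pretrained(
        model_name_or_path,
        config=config,
        seed=0,
        dtype=jnp.float32,
    )
else:
    # Create a new model from the config alone (mirrors the else branch).
    model = DalleBart(
        config,
        seed=0,
        dtype=jnp.float32,
    )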