boris committed on
Commit
b257ca8
1 Parent(s): 1f57ad7

fix(train): update model name

Browse files
Files changed (1) hide show
  1. tools/train/train.py +3 -4
tools/train/train.py CHANGED
@@ -41,10 +41,9 @@ from flax.training import train_state
41
  from flax.training.common_utils import get_metrics, onehot, shard_prng_key
42
  from tqdm import tqdm
43
  from transformers import AutoTokenizer, HfArgumentParser
44
- from transformers.models.bart.modeling_flax_bart import BartConfig
45
 
46
  from dalle_mini.data import Dataset
47
- from dalle_mini.model import DalleBartConfig, DalleBartForConditionalGeneration
48
 
49
  logger = logging.getLogger(__name__)
50
 
@@ -418,7 +417,7 @@ def main():
418
 
419
  # Load or create new model
420
  if model_args.model_name_or_path:
421
- model = DalleBartForConditionalGeneration.from_pretrained(
422
  model_args.model_name_or_path,
423
  config=config,
424
  seed=training_args.seed_model,
@@ -427,7 +426,7 @@ def main():
427
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
428
  print(model.params)
429
  else:
430
- model = DalleBartForConditionalGeneration(
431
  config,
432
  seed=training_args.seed_model,
433
  dtype=getattr(jnp, model_args.dtype),
 
41
  from flax.training.common_utils import get_metrics, onehot, shard_prng_key
42
  from tqdm import tqdm
43
  from transformers import AutoTokenizer, HfArgumentParser
 
44
 
45
  from dalle_mini.data import Dataset
46
+ from dalle_mini.model import DalleBartConfig, DalleBart
47
 
48
  logger = logging.getLogger(__name__)
49
 
 
417
 
418
  # Load or create new model
419
  if model_args.model_name_or_path:
420
+ model = DalleBart.from_pretrained(
421
  model_args.model_name_or_path,
422
  config=config,
423
  seed=training_args.seed_model,
 
426
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
427
  print(model.params)
428
  else:
429
+ model = DalleBart(
430
  config,
431
  seed=training_args.seed_model,
432
  dtype=getattr(jnp, model_args.dtype),