boris committed on
Commit
f69b21b
2 Parent(s): bbbf7c8 f9d51f7

Load from wandb artifact (#121)


Load model & tokenizer from artifacts.
Fixes #116
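
With this change, `DalleBart`, `DalleBartConfig`, and `DalleBartTokenizer` all accept a W&B artifact reference anywhere `from_pretrained` previously took a Hub id or local path. A minimal usage sketch; the artifact reference below is hypothetical (the `entity/project/name:alias` form) and should be replaced with an artifact from your own W&B project:

from dalle_mini.model import DalleBart, DalleBartTokenizer

# Hypothetical artifact reference (entity/project/name:alias).
artifact_ref = "my-entity/dalle-mini/model-3f8xmdlp:latest"

# The ":" in the reference routes both calls through the new mixin, which
# downloads the artifact and then loads from the local directory as usual.
model = DalleBart.from_pretrained(artifact_ref)
tokenizer = DalleBartTokenizer.from_pretrained(artifact_ref)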

src/dalle_mini/model/__init__.py CHANGED
@@ -1,2 +1,3 @@
 from .configuration import DalleBartConfig
 from .modeling import DalleBart
+from .tokenizer import DalleBartTokenizer
src/dalle_mini/model/configuration.py CHANGED
@@ -18,10 +18,12 @@ import warnings
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
+from .wandb_pretrained import PretrainedFromWandbMixin
+
 logger = logging.get_logger(__name__)
 
 
-class DalleBartConfig(PretrainedConfig):
+class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
     model_type = "dallebart"
     keys_to_ignore_at_inference = ["past_key_values"]
     attribute_map = {
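
The mixin is listed first in the bases on purpose: Python's MRO then resolves `from_pretrained` to `PretrainedFromWandbMixin` before `PretrainedConfig`, and the mixin hands off to the next class via `super()`. A toy sketch of this cooperative-inheritance pattern (all names here are illustrative, not part of the codebase):

class Loader:  # stands in for PretrainedConfig
    @classmethod
    def from_pretrained(cls, name):
        return f"loaded from {name}"


class ArtifactMixin:  # stands in for PretrainedFromWandbMixin
    @classmethod
    def from_pretrained(cls, name):
        if ":" in name:
            name = "/tmp/downloaded-artifact"  # pretend we downloaded it
        return super().from_pretrained(name)


class Config(ArtifactMixin, Loader):  # mixin first, so it intercepts
    pass


print(Config.from_pretrained("entity/project/model:latest"))
# -> loaded from /tmp/downloaded-artifact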
src/dalle_mini/model/modeling.py CHANGED
@@ -46,6 +46,7 @@ from transformers.models.bart.modeling_flax_bart import (
 from transformers.utils import logging
 
 from .configuration import DalleBartConfig
+from .wandb_pretrained import PretrainedFromWandbMixin
 
 logger = logging.get_logger(__name__)
 
@@ -419,7 +420,9 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
     )
 
 
-class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
+class DalleBart(
+    PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
+):
     """
     Edits:
     - renamed from FlaxBartForConditionalGeneration
src/dalle_mini/model/tokenizer.py ADDED
@@ -0,0 +1,11 @@
+""" DalleBart tokenizer """
+from transformers import BartTokenizer
+from transformers.utils import logging
+
+from .wandb_pretrained import PretrainedFromWandbMixin
+
+logger = logging.get_logger(__name__)
+
+
+class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizer):
+    pass
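
Since the class body is just `pass`, the tokenizer behaves exactly like a `BartTokenizer` once constructed; only the loading entry point changes. For example (hypothetical artifact reference again):

from dalle_mini.model import DalleBartTokenizer

tokenizer = DalleBartTokenizer.from_pretrained("my-entity/dalle-mini/model-3f8xmdlp:latest")
ids = tokenizer("a painting of a dog", return_tensors="np").input_ids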
src/dalle_mini/model/wandb_pretrained.py ADDED
@@ -0,0 +1,21 @@
+import os
+
+import wandb
+
+
+class PretrainedFromWandbMixin:
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """
+        Initializes from a wandb artifact, or delegates loading to the superclass.
+        """
+        if ":" in pretrained_model_name_or_path and not os.path.isdir(
+            pretrained_model_name_or_path
+        ):
+            # wandb artifact
+            artifact = wandb.Api().artifact(pretrained_model_name_or_path)
+            pretrained_model_name_or_path = artifact.download()
+
+        return super(PretrainedFromWandbMixin, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
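
The only dispatch heuristic is the `:` alias separator combined with the directory check, so plain Hub ids and local paths fall through to the superclass untouched. A small sketch of that routing test in isolation (references hypothetical):

import os


def is_artifact_reference(name: str) -> bool:
    # Same test as the mixin: an alias separator that is not shadowed
    # by a local directory of the same name.
    return ":" in name and not os.path.isdir(name)


assert is_artifact_reference("my-entity/dalle-mini/model-3f8xmdlp:latest")
assert not is_artifact_reference("facebook/bart-large")  # Hub model id
assert not is_artifact_reference(os.getcwd())            # local directory

Note that `wandb.Api()` talks to the W&B backend, so loading from an artifact reference generally requires W&B credentials (e.g. via `wandb login`).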
tools/train/train.py CHANGED
@@ -44,7 +44,7 @@ from tqdm import tqdm
 from transformers import AutoTokenizer, HfArgumentParser
 
 from dalle_mini.data import Dataset
-from dalle_mini.model import DalleBart, DalleBartConfig
+from dalle_mini.model import DalleBart, DalleBartConfig, DalleBartTokenizer
 
 logger = logging.getLogger(__name__)
 
@@ -58,8 +58,9 @@ class ModelArguments:
     model_name_or_path: Optional[str] = field(
         default=None,
         metadata={
-            "help": "The model checkpoint for weights initialization."
-            "Don't set if you want to train a model from scratch."
+            "help": "The model checkpoint for weights initialization. "
+            "Don't set if you want to train a model from scratch. "
+            "W&B artifact references are supported in addition to the sources supported by `PreTrainedModel`."
        },
     )
     config_name: Optional[str] = field(
@@ -482,13 +483,15 @@ def main():
 
         # load model
         model = DalleBart.from_pretrained(
-            artifact_dir, dtype=getattr(jnp, model_args.dtype), abstract_init=True
+            artifact_dir,
+            dtype=getattr(jnp, model_args.dtype),
+            abstract_init=True,
         )
         # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
         print(model.params)
 
         # load tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = DalleBartTokenizer.from_pretrained(
             artifact_dir,
             use_fast=True,
         )
@@ -498,7 +501,7 @@ def main():
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
     else:
-        config = DalleBartConfig.from_pretrained(model_args.model_name_or_path)
+        config = None
 
     # Load or create new model
     if model_args.model_name_or_path:
@@ -524,7 +527,7 @@ def main():
             model_args.tokenizer_name, use_fast=True
         )
     else:
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = DalleBartTokenizer.from_pretrained(
             model_args.model_name_or_path,
             use_fast=True,
         )
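
With the mixin in place, `--model_name_or_path` can now point at an artifact reference and the `DalleBart*` classes resolve it transparently, which is presumably why the fallback `config` becomes `None`: the config ships with the checkpoint and no longer needs to be fetched separately. A sketch mirroring the updated loading path in `tools/train/train.py` (hypothetical artifact reference; `dtype` and `abstract_init` as in the diff above):

import jax.numpy as jnp

from dalle_mini.model import DalleBart, DalleBartTokenizer

# In train.py this value arrives as --model_name_or_path.
model_name_or_path = "my-entity/dalle-mini/model-3f8xmdlp:latest"

model = DalleBart.from_pretrained(
    model_name_or_path,
    dtype=jnp.float32,   # train.py uses getattr(jnp, model_args.dtype)
    abstract_init=True,  # dalle-mini-specific kwarg, see the diff above
)
tokenizer = DalleBartTokenizer.from_pretrained(model_name_or_path, use_fast=True)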