Pedro Cuenca commited on
Commit
1023afa
1 Parent(s): f5dba1e

Override from_pretrained to support wandb artifacts.

Browse files
Files changed (1) hide show
  1. src/dalle_mini/model/modeling.py +16 -0
src/dalle_mini/model/modeling.py CHANGED
@@ -44,6 +44,7 @@ from transformers.models.bart.modeling_flax_bart import (
44
  FlaxBartPreTrainedModel,
45
  )
46
  from transformers.utils import logging
 
47
 
48
  from .configuration import DalleBartConfig
49
 
@@ -561,3 +562,18 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
561
  outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
562
 
563
  return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  FlaxBartPreTrainedModel,
45
  )
46
  from transformers.utils import logging
47
+ import wandb
48
 
49
  from .configuration import DalleBartConfig
50
 
 
562
  outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
563
 
564
  return outputs
565
+
566
+ @classmethod
567
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
568
+ """
569
+ Initializes from a wandb artifact, or delegates loading to the superclass.
570
+ """
571
+ if ':' in pretrained_model_name_or_path:
572
+ # wandb artifact
573
+ artifact = wandb.Api().artifact(pretrained_model_name_or_path)
574
+
575
+ # we download everything, including opt_state, so we can resume training if needed
576
+ # see also: #120
577
+ pretrained_model_name_or_path = artifact.download()
578
+
579
+ return super(DalleBart, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)