Spaces:
Running
Running
Pedro Cuenca
commited on
Commit
•
1023afa
1
Parent(s):
f5dba1e
Override from_pretrained to support wandb artifacts.
Browse files
src/dalle_mini/model/modeling.py
CHANGED
@@ -44,6 +44,7 @@ from transformers.models.bart.modeling_flax_bart import (
|
|
44 |
FlaxBartPreTrainedModel,
|
45 |
)
|
46 |
from transformers.utils import logging
|
|
|
47 |
|
48 |
from .configuration import DalleBartConfig
|
49 |
|
@@ -561,3 +562,18 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
|
|
561 |
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
562 |
|
563 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
FlaxBartPreTrainedModel,
|
45 |
)
|
46 |
from transformers.utils import logging
|
47 |
+
import wandb
|
48 |
|
49 |
from .configuration import DalleBartConfig
|
50 |
|
|
|
562 |
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
563 |
|
564 |
return outputs
|
565 |
+
|
566 |
+
@classmethod
|
567 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
568 |
+
"""
|
569 |
+
Initializes from a wandb artifact, or delegates loading to the superclass.
|
570 |
+
"""
|
571 |
+
if ':' in pretrained_model_name_or_path:
|
572 |
+
# wandb artifact
|
573 |
+
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
574 |
+
|
575 |
+
# we download everything, including opt_state, so we can resume training if needed
|
576 |
+
# see also: #120
|
577 |
+
pretrained_model_name_or_path = artifact.download()
|
578 |
+
|
579 |
+
return super(DalleBart, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|