Commit 7e48337
Pedro Cuenca committed
Parent(s): 2b2be9b

Tokenizer, config, model can be loaded from wandb.
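With this change, DalleBartConfig, DalleBartTokenizer, and DalleBart all accept a wandb artifact reference in from_pretrained, in addition to the usual local path or hub id. A minimal usage sketch; the artifact reference "entity/project/artifact:v0" is a hypothetical placeholder, not a real artifact:

from dalle_mini.model import DalleBart, DalleBartConfig, DalleBartTokenizer

# A wandb artifact reference carries a ":" version suffix, which is how the
# new mixin tells it apart from a local directory or hub model id.
tokenizer = DalleBartTokenizer.from_pretrained("entity/project/artifact:v0")
config = DalleBartConfig.from_pretrained("entity/project/artifact:v0")
model = DalleBart.from_pretrained("entity/project/artifact:v0")

# Anything without a colon is delegated unchanged to transformers:
model = DalleBart.from_pretrained("./local-checkpoint")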
src/dalle_mini/model/__init__.py
CHANGED
@@ -1,2 +1,3 @@
 from .configuration import DalleBartConfig
 from .modeling import DalleBart
+from .tokenizer import DalleBartTokenizer
src/dalle_mini/model/configuration.py
CHANGED
@@ -18,10 +18,12 @@ import warnings
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
+from .wandb_pretrained import PretrainedFromWandbMixin
+
 logger = logging.get_logger(__name__)
 
 
-class DalleBartConfig(PretrainedConfig):
+class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
     model_type = "dallebart"
     keys_to_ignore_at_inference = ["past_key_values"]
     attribute_map = {
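The order of the base classes is what makes the mixin take effect: Python resolves methods left to right along the MRO, so PretrainedFromWandbMixin.from_pretrained is found before PretrainedConfig.from_pretrained and can hand off to it via super(). A minimal sketch of the pattern, with illustrative names that are not from the repo:

class LoaderBase:
    @classmethod
    def from_pretrained(cls, path):
        return f"base loaded {path}"


class ResolveMixin:
    @classmethod
    def from_pretrained(cls, path):
        # Stand-in for the artifact-download step in PretrainedFromWandbMixin.
        path = path.upper()
        return super().from_pretrained(path)


class Thing(ResolveMixin, LoaderBase):  # mixin listed first, as above
    pass


print(Thing.from_pretrained("x"))  # -> "base loaded X": mixin runs, then base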
src/dalle_mini/model/modeling.py
CHANGED
@@ -15,14 +15,12 @@
 """ DalleBart model. """
 
 import math
-import os
 from functools import partial
 from typing import Optional, Tuple
 
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-import wandb
 from flax.core.frozen_dict import unfreeze
 from flax.linen import make_causal_mask
 from flax.traverse_util import flatten_dict
@@ -48,6 +46,7 @@ from transformers.models.bart.modeling_flax_bart import (
 from transformers.utils import logging
 
 from .configuration import DalleBartConfig
+from .wandb_pretrained import PretrainedFromWandbMixin
 
 logger = logging.get_logger(__name__)
 
@@ -421,7 +420,9 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
         )
 
 
-class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
+class DalleBart(
+    PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
+):
     """
     Edits:
     - renamed from FlaxBartForConditionalGeneration
@@ -563,24 +564,3 @@ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
         outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
 
         return outputs
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        """
-        Initializes from a wandb artifact, or delegates loading to the superclass.
-        """
-        if ":" in pretrained_model_name_or_path and not os.path.isdir(
-            pretrained_model_name_or_path
-        ):
-            # wandb artifact
-            artifact = wandb.Api().artifact(pretrained_model_name_or_path)
-
-            # we download everything, including opt_state, so we can resume training if needed
-            # see also: #120
-            pretrained_model_name_or_path = artifact.download()
-
-        model = super(DalleBart, cls).from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
-        )
-        model.config.resolved_name_or_path = pretrained_model_name_or_path
-        return model
src/dalle_mini/model/tokenizer.py
ADDED
@@ -0,0 +1,11 @@
+""" DalleBart tokenizer """
+from transformers import BartTokenizer
+from transformers.utils import logging
+
+from .wandb_pretrained import PretrainedFromWandbMixin
+
+logger = logging.get_logger(__name__)
+
+
+class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizer):
+    pass
src/dalle_mini/model/wandb_pretrained.py
ADDED
@@ -0,0 +1,20 @@
+import os
+import wandb
+
+
+class PretrainedFromWandbMixin:
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """
+        Initializes from a wandb artifact, or delegates loading to the superclass.
+        """
+        if ":" in pretrained_model_name_or_path and not os.path.isdir(
+            pretrained_model_name_or_path
+        ):
+            # wandb artifact
+            artifact = wandb.Api().artifact(pretrained_model_name_or_path)
+            pretrained_model_name_or_path = artifact.download()
+
+        return super(PretrainedFromWandbMixin, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )