Spaces:
Running
Running
feat(model): set default config for legacy models
Browse files- dalle_mini/model.py +5 -0
dalle_mini/model.py
CHANGED
@@ -46,6 +46,11 @@ class CustomFlaxBartForConditionalGenerationModule(
|
|
46 |
FlaxBartForConditionalGenerationModule
|
47 |
):
|
48 |
def setup(self):
|
|
|
|
|
|
|
|
|
|
|
49 |
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
50 |
self.lm_head = nn.Dense(
|
51 |
self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
|
|
|
46 |
FlaxBartForConditionalGenerationModule
|
47 |
):
|
48 |
def setup(self):
|
49 |
+
# set default config
|
50 |
+
self.config.normalize_text = getattr(self.config, "normalize_text", False)
|
51 |
+
self.config.image_length = getattr(self.config, "image_length", 256)
|
52 |
+
self.config.image_vocab_size = getattr(self.config, "image_vocab_size", 16384)
|
53 |
+
|
54 |
self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
|
55 |
self.lm_head = nn.Dense(
|
56 |
self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
|