feat(model): allow bias (#152)
src/dalle_mini/__init__.py
CHANGED
@@ -1,3 +1,3 @@
-__version__ = "0.0.3"
+__version__ = "0.0.4"
 
 from .model import DalleBart, DalleBartProcessor
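Note: the version bump is what the training script further down relies on when it logs dalle_mini.__version__. As a quick illustrative check (not part of the diff), assuming the package is installed from this revision:

    import dalle_mini
    from dalle_mini.model import DalleBart, DalleBartProcessor  # re-exported in __init__.py

    print(dalle_mini.__version__)  # expected to print "0.0.4" at this revision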
src/dalle_mini/model/configuration.py
CHANGED
@@ -58,6 +58,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         tie_word_embeddings=False,  # different modalities and sizes
         do_sample=True,
         # transformer variants
+        use_bias=False,  # use bias in attention and dense layers (except for lm_head)
         ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
         ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
         use_head_scale=False,  # used in NormFormer
@@ -65,7 +66,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         tau_init=0.05,  # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False,  # used in Deepnet
         use_glu=False,  # "GLU Variants Improve Transformer"
-        use_alibi=False,  # from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
+        use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
         sinkhorn_iters=1,  # used in SinkFormers
         use_final_ln_encoder=False,  # final layer normalization in encoder
         use_final_ln_decoder=False,  # final layer normalization in decoder
@@ -77,7 +78,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         self.normalize_text = normalize_text
 
         # transformer variants
-        self.use_head_scale = use_head_scale
+        self.use_bias = use_bias
         assert ln_type in [
             "rmsnorm",
             "layernorm",
@@ -92,6 +93,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
             "postln",
             "preln",
         ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
+        self.use_head_scale = use_head_scale
         assert use_alibi is False, "use_alibi is not supported yet"
         self.ln_positions = ln_positions
         self.use_cosine_attention = use_cosine_attention
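Note: use_bias defaults to False, so existing configs keep the previous bias-free behaviour; the flag is stored on the config and read by the modeling code below. A minimal sketch of opting in, assuming DalleBartConfig can be constructed directly with keyword arguments and that all other arguments keep their defaults:

    from dalle_mini.model.configuration import DalleBartConfig

    # Sketch only: enable biases in attention and dense layers (lm_head excluded).
    # use_alibi must stay False, since the constructor asserts it is unsupported.
    config = DalleBartConfig(
        use_bias=True,
        ln_type="layernorm",
        ln_positions="normformer",
    )
    assert config.use_bias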
src/dalle_mini/model/modeling.py
CHANGED
@@ -444,7 +444,7 @@ class GLU(nn.Module):
         w = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -453,7 +453,7 @@ class GLU(nn.Module):
         v = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -473,7 +473,7 @@ class GLU(nn.Module):
         x = nn.Dense(
             self.embed_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -509,7 +509,7 @@ class FFN(nn.Module):
         x = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -528,7 +528,7 @@ class FFN(nn.Module):
         x = nn.Dense(
             self.embed_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -580,7 +580,7 @@ class FlaxBartEncoderLayer(nn.Module):
             embed_dim=embed_dim,
             num_heads=self.config.encoder_attention_heads,
             dropout=self.config.attention_dropout,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=True,
         )(hidden_states=hidden_states, attention_mask=attention_mask)
@@ -686,7 +686,7 @@ class FlaxBartDecoderLayer(nn.Module):
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
             causal=True,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
         )(
@@ -724,7 +724,7 @@ class FlaxBartDecoderLayer(nn.Module):
             embed_dim=embed_dim,
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
         )(
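Note: every hunk above follows the same pattern - a hard-coded bias setting on nn.Dense (use_bias=...) or on the attention module (bias=...) is replaced by the single config flag. An illustrative, self-contained sketch of that pattern in Flax (TinyFFN and its arguments are hypothetical stand-ins, not code from this repo):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class TinyFFN(nn.Module):
        ffn_dim: int
        use_bias: bool = False  # stand-in for self.config.use_bias

        @nn.compact
        def __call__(self, x):
            # The bias parameter is only created when use_bias is True.
            x = nn.Dense(
                self.ffn_dim,
                use_bias=self.use_bias,
                kernel_init=jax.nn.initializers.normal(0.02),
            )(x)
            return nn.gelu(x)

    # With use_bias=True the parameter tree gains a "bias" leaf for the Dense layer.
    params = TinyFFN(ffn_dim=8, use_bias=True).init(jax.random.PRNGKey(0), jnp.ones((1, 4)))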
tools/train/train.py
CHANGED
@@ -49,6 +49,7 @@ from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
 from tqdm import tqdm
 from transformers import HfArgumentParser
 
+import dalle_mini
 from dalle_mini.data import Dataset
 from dalle_mini.model import (
     DalleBart,
@@ -675,6 +676,7 @@ def main():
                 "transformers": transformers.__version__,
                 "datasets": datasets.__version__,
                 "wandb": wandb.__version__,
+                "dalle_mini": dalle_mini.__version__,
             },
         }
     )
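Note: with the new import, the run metadata logged to wandb now also pins the dalle_mini package version. A hedged sketch of the resulting versions block (the surrounding wandb call in train.py is omitted here; key names are taken from the diff):

    import datasets
    import transformers
    import wandb

    import dalle_mini

    versions = {
        "transformers": transformers.__version__,
        "datasets": datasets.__version__,
        "wandb": wandb.__version__,
        "dalle_mini": dalle_mini.__version__,  # newly recorded for reproducibility
    }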