feat: force final ln in encoder
src/dalle_mini/model/configuration.py
CHANGED
@@ -60,12 +60,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         # transformer variants
         ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
         ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "deepnet" (same as postln)
-        head_scale=
+        head_scale=False, # used in NormFormer
         use_cosine_attention=False, # used in Swin v2
         tau_init=0.05, # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False, # used in Deepnet
-        use_glu=
-
+        use_glu=False, # "GLU Variants Improve Transformer"
+        # parameters that should not be necessary but could affect results
+        force_ln_scale=True, # force scale in layernorm even when followed by dense layers
+        force_final_ln_encoder=False, # force layer normalization in encoder final layer even when followed by dense layers
         **kwargs,
     ):
         # text normalizer
@@ -91,7 +93,8 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         self.tau_init = tau_init
         self.use_deepnet_scaling = use_deepnet_scaling
         self.use_glu = use_glu
-        self.
+        self.force_ln_scale = force_ln_scale
+        self.force_final_ln_encoder = force_final_ln_encoder

         # common parameters
         self.encoder_vocab_size = encoder_vocab_size
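For reference, a minimal sketch (not part of the commit) of how the two new flags surface on the config. The import path follows the file location above, and it is assumed that the remaining constructor arguments keep their defaults.

# Sketch only: exercising the new flags added in this commit.
# Assumes all other DalleBartConfig arguments have usable defaults.
from dalle_mini.model.configuration import DalleBartConfig

config = DalleBartConfig(
    ln_type="layernorm",
    ln_positions="normformer",
    force_ln_scale=True,          # keep a learned scale in every layernorm
    force_final_ln_encoder=True,  # force layer normalization in the encoder's final layer
)
assert config.force_ln_scale is True
assert config.force_final_ln_encoder is True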
src/dalle_mini/model/modeling.py
CHANGED
@@ -378,7 +378,7 @@ class GLU(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(x)
         w = nn.Dense(
             self.ffn_dim,
@@ -403,7 +403,7 @@ class GLU(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -443,7 +443,7 @@ class FFN(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(x)
         x = nn.Dense(
             self.ffn_dim,
@@ -459,7 +459,7 @@ class FFN(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -512,7 +512,7 @@ class FlaxBartEncoderLayer(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
@@ -561,7 +561,7 @@ class FlaxBartEncoderLayer(nn.Module):
             use_scale = (
                 self.use_scale
                 or self.config.ln_positions == "postln"
-                or self.config.
+                or self.config.force_ln_scale
             )
             hidden_states = norm(
                 self.config.ln_type,
@@ -617,7 +617,7 @@ class FlaxBartDecoderLayer(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
@@ -656,7 +656,7 @@ class FlaxBartDecoderLayer(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.
+            use_scale=self.config.force_ln_scale,
         )(hidden_states)
         hidden_states, cross_attn_weights = FlaxBartAttention(
             config=self.config,
@@ -709,7 +709,7 @@ class FlaxBartDecoderLayer(nn.Module):
             use_scale = (
                 self.use_scale
                 or self.config.ln_positions == "postln"
-                or self.config.
+                or self.config.force_ln_scale
             )
             hidden_states = norm(
                 self.config.ln_type,
@@ -761,8 +761,9 @@ class FlaxBartEncoderLayerCollection(nn.Module):
             # or every 6 layers for Swin v2
             # not needed for other models which use layernorm before x-attention
             # ignored args for deepnet which always add a norm with scale
-            add_norm = self.config.
-
+            add_norm = self.config.force_final_ln_encoder or (
+                self.config.ln_positions == "swinv2"
+                and ((i == n_layers - 1) or ((i + 1) % 6 == 0))
             )
             # we don't need to scale the norm for the last layer
             use_scale = i != n_layers - 1