boris committed
Commit 32f4ba5
1 Parent(s): 5bd4c20

feat: force final ln in encoder

src/dalle_mini/model/configuration.py CHANGED
@@ -60,12 +60,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         # transformer variants
         ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
         ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "deepnet" (same as postln)
-        head_scale=True, # used in NormFormer
+        head_scale=False, # used in NormFormer
         use_cosine_attention=False, # used in Swin v2
         tau_init=0.05, # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False, # used in Deepnet
-        use_glu=True, # "GLU Variants Improve Transformer"
-        use_all_scale=True, # use scale in layernorm even when seemingly unnecessary
+        use_glu=False, # "GLU Variants Improve Transformer"
+        # parameters that should not be necessary but could affect results
+        force_ln_scale=True, # force scale in layernorm even when followed by dense layers
+        force_final_ln_encoder=False, # force layer normalization in encoder final layer even when followed by dense layers
         **kwargs,
     ):
         # text normalizer
@@ -91,7 +93,8 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         self.tau_init = tau_init
         self.use_deepnet_scaling = use_deepnet_scaling
         self.use_glu = use_glu
-        self.use_all_scale = use_all_scale
+        self.force_ln_scale = force_ln_scale
+        self.force_final_ln_encoder = force_final_ln_encoder
 
         # common parameters
         self.encoder_vocab_size = encoder_vocab_size
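
For reference, a minimal, hypothetical usage sketch (not part of this commit) showing how the renamed and newly added flags could be passed to DalleBartConfig; it assumes the remaining constructor arguments keep their defaults.

from dalle_mini.model.configuration import DalleBartConfig

# Hypothetical example: the argument values below are for illustration only.
config = DalleBartConfig(
    ln_type="layernorm",
    ln_positions="normformer",
    head_scale=False,             # default flipped to False in this commit
    use_glu=False,                # default flipped to False in this commit
    force_ln_scale=True,          # renamed from use_all_scale
    force_final_ln_encoder=True,  # new flag: force a final layer norm in the encoder
)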
src/dalle_mini/model/modeling.py CHANGED
@@ -378,7 +378,7 @@ class GLU(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.use_all_scale,
+            use_scale=self.config.force_ln_scale,
         )(x)
         w = nn.Dense(
             self.ffn_dim,
@@ -403,7 +403,7 @@ class GLU(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.use_all_scale,
+            use_scale=self.config.force_ln_scale,
         )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -443,7 +443,7 @@ class FFN(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.use_all_scale,
+            use_scale=self.config.force_ln_scale,
         )(x)
         x = nn.Dense(
             self.ffn_dim,
@@ -459,7 +459,7 @@ class FFN(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.use_all_scale,
+            use_scale=self.config.force_ln_scale,
         )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -512,7 +512,7 @@ class FlaxBartEncoderLayer(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.use_all_scale,
+            use_scale=self.config.force_ln_scale,
         )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
@@ -561,7 +561,7 @@ class FlaxBartEncoderLayer(nn.Module):
         use_scale = (
             self.use_scale
             or self.config.ln_positions == "postln"
-            or self.config.use_all_scale
+            or self.config.force_ln_scale
         )
         hidden_states = norm(
             self.config.ln_type,
@@ -617,7 +617,7 @@ class FlaxBartDecoderLayer(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.use_all_scale,
+            use_scale=self.config.force_ln_scale,
         )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
@@ -656,7 +656,7 @@ class FlaxBartDecoderLayer(nn.Module):
             self.config.ln_type,
             dtype=self.dtype,
             epsilon=1e-05,
-            use_scale=self.config.use_all_scale,
+            use_scale=self.config.force_ln_scale,
         )(hidden_states)
         hidden_states, cross_attn_weights = FlaxBartAttention(
             config=self.config,
@@ -709,7 +709,7 @@ class FlaxBartDecoderLayer(nn.Module):
         use_scale = (
             self.use_scale
             or self.config.ln_positions == "postln"
-            or self.config.use_all_scale
+            or self.config.force_ln_scale
         )
         hidden_states = norm(
             self.config.ln_type,
@@ -761,8 +761,9 @@ class FlaxBartEncoderLayerCollection(nn.Module):
         # or every 6 layers for Swin v2
         # not needed for other models which use layernorm before x-attention
         # ignored args for deepnet which always add a norm with scale
-        add_norm = self.config.ln_positions == "swinv2" and (
-            (i == n_layers - 1) or ((i + 1) % 6 == 0)
+        add_norm = self.config.force_final_ln_encoder or (
+            self.config.ln_positions == "swinv2"
+            and ((i == n_layers - 1) or ((i + 1) % 6 == 0))
         )
         # we don't need to scale the norm for the last layer
         use_scale = i != n_layers - 1
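
To make the last hunk easier to follow, here is an illustrative standalone sketch (not from the repository) that mirrors the new add_norm expression in FlaxBartEncoderLayerCollection; the helper name and signature are hypothetical.

# Hypothetical helper mirroring the add_norm expression from the hunk at -761,8 above.
def add_norm_for_layer(i, n_layers, ln_positions, force_final_ln_encoder):
    # Swin v2 adds an extra norm every 6 layers and after the last layer;
    # with force_final_ln_encoder the expression evaluates True independent of i.
    return force_final_ln_encoder or (
        ln_positions == "swinv2" and ((i == n_layers - 1) or ((i + 1) % 6 == 0))
    )

# e.g. with 12 layers and the Swin v2 schedule, force flag off:
# [add_norm_for_layer(i, 12, "swinv2", False) for i in range(12)]
# -> True only at layer indices 5 and 11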