boris committed
Commit 361a994
1 Parent(s): 02b2308

feat(model): allow bias (#152)

src/dalle_mini/__init__.py CHANGED
@@ -1,3 +1,3 @@
-__version__ = "0.0.3"
+__version__ = "0.0.4"
 
 from .model import DalleBart, DalleBartProcessor
src/dalle_mini/model/configuration.py CHANGED
@@ -58,6 +58,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         tie_word_embeddings=False,  # different modalities and sizes
         do_sample=True,
         # transformer variants
+        use_bias=False,  # use bias in attention and dense layers (except for lm_head)
         ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
         ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
         use_head_scale=False,  # used in NormFormer
@@ -65,7 +66,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         tau_init=0.05,  # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False,  # used in Deepnet
         use_glu=False,  # "GLU Variants Improve Transformer"
-        use_alibi=False,  # from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
+        use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
         sinkhorn_iters=1,  # used in SinkFormers
         use_final_ln_encoder=False,  # final layer normalization in encoder
         use_final_ln_decoder=False,  # final layer normalization in decoder
@@ -77,7 +78,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         self.normalize_text = normalize_text
 
         # transformer variants
-        self.use_head_scale = use_head_scale  # per Normformer
+        self.use_bias = use_bias
         assert ln_type in [
             "rmsnorm",
             "layernorm",
@@ -92,6 +93,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
             "postln",
             "preln",
         ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
+        self.use_head_scale = use_head_scale
         assert use_alibi is False, "use_alibi is not supported yet"
         self.ln_positions = ln_positions
         self.use_cosine_attention = use_cosine_attention
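
As a rough usage sketch (not part of this commit), the new flag is just another keyword argument of DalleBartConfig, defaulting to False; enabling it means passing use_bias=True when the config is built, while every other argument keeps its default:

# Hypothetical usage, assuming the package is installed; not code from this commit.
from dalle_mini.model.configuration import DalleBartConfig

config = DalleBartConfig(use_bias=True)  # bias enabled in attention and dense layers
assert config.use_bias is True           # the default remains False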
src/dalle_mini/model/modeling.py CHANGED
@@ -444,7 +444,7 @@ class GLU(nn.Module):
         w = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -453,7 +453,7 @@ class GLU(nn.Module):
         v = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -473,7 +473,7 @@ class GLU(nn.Module):
         x = nn.Dense(
             self.embed_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -509,7 +509,7 @@ class FFN(nn.Module):
         x = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -528,7 +528,7 @@ class FFN(nn.Module):
         x = nn.Dense(
             self.embed_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -580,7 +580,7 @@ class FlaxBartEncoderLayer(nn.Module):
             embed_dim=embed_dim,
             num_heads=self.config.encoder_attention_heads,
             dropout=self.config.attention_dropout,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=True,
         )(hidden_states=hidden_states, attention_mask=attention_mask)
@@ -686,7 +686,7 @@ class FlaxBartDecoderLayer(nn.Module):
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
             causal=True,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
         )(
@@ -724,7 +724,7 @@ class FlaxBartDecoderLayer(nn.Module):
             embed_dim=embed_dim,
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
         )(
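
The modeling changes all follow one pattern: the hard-coded use_bias=False / bias=False arguments are replaced by the config flag. Below is a minimal, self-contained Flax sketch of that pattern; BiasConfig and TinyFFN are illustrative names, not classes from dalle-mini.

# Standalone sketch of forwarding a config-level bias flag to Flax Dense layers.
# BiasConfig / TinyFFN are hypothetical; only the pattern mirrors the diff above.
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.linen as nn


@dataclass
class BiasConfig:
    use_bias: bool = False  # mirrors the DalleBartConfig.use_bias default


class TinyFFN(nn.Module):
    config: BiasConfig
    ffn_dim: int = 32
    embed_dim: int = 16

    @nn.compact
    def __call__(self, x):
        # The flag is forwarded instead of a hard-coded use_bias=False.
        x = nn.Dense(self.ffn_dim, use_bias=self.config.use_bias)(x)
        x = nn.gelu(x)
        x = nn.Dense(self.embed_dim, use_bias=self.config.use_bias)(x)
        return x


x = jnp.ones((1, 16))
params = TinyFFN(config=BiasConfig(use_bias=True)).init(jax.random.PRNGKey(0), x)
# With use_bias=True each Dense now holds both a "kernel" and a "bias" parameter.
print(jax.tree_util.tree_map(jnp.shape, params))

With use_bias=False the "bias" entries are simply absent, which is why flipping this flag changes the model's parameter tree and therefore its checkpoints.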
tools/train/train.py CHANGED
@@ -49,6 +49,7 @@ from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
 from tqdm import tqdm
 from transformers import HfArgumentParser
 
+import dalle_mini
 from dalle_mini.data import Dataset
 from dalle_mini.model import (
     DalleBart,
@@ -675,6 +676,7 @@ def main():
                 "transformers": transformers.__version__,
                 "datasets": datasets.__version__,
                 "wandb": wandb.__version__,
+                "dalle_mini": dalle_mini.__version__,
             },
         }
     )
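
For reference, here is a minimal sketch of how such version strings end up in a wandb run config. It is not the actual train.py code: the project name is made up and offline mode is used so no wandb account is needed.

# Sketch of logging package versions to wandb; "dalle-mini-demo" is a hypothetical project.
import dalle_mini
import transformers
import wandb

wandb.init(
    project="dalle-mini-demo",
    mode="offline",  # write the run locally, no login required
    config={
        "versions": {
            "transformers": transformers.__version__,
            "dalle_mini": dalle_mini.__version__,  # exposed via src/dalle_mini/__init__.py
        }
    },
)
wandb.finish()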