boris committed on
Commit
472c4cc
·
1 Parent(s): 3b8d8cb

feat: add cogview

Browse files
README.md CHANGED
@@ -124,6 +124,7 @@ Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-t
124
  - "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
125
  - "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
126
  - "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
 
127
  - "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
128
 
129
  Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
@@ -225,6 +226,17 @@ Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization f
225
  }
226
  ```
227
 
 
 
 
 
 
 
 
 
 
 
 
228
  ```text
229
  @misc{zhang2019root,
230
  title = {Root Mean Square Layer Normalization},
 
124
  - "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
125
  - "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
126
  - "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
127
+ - "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)"
128
  - "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
129
 
130
  Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
 
226
  }
227
  ```
228
 
229
+ ```text
230
+ @misc{ding2021cogview,
231
+ title = {CogView: Mastering Text-to-Image Generation via Transformers},
232
+ author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
233
+ year = {2021},
234
+ eprint = {2105.13290},
235
+ archivePrefix = {arXiv},
236
+ primaryClass = {cs.CV}
237
+ }
238
+ ```
239
+
240
  ```text
241
  @misc{zhang2019root,
242
  title = {Root Mean Square Layer Normalization},
src/dalle_mini/model/configuration.py CHANGED
@@ -60,7 +60,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
60
  # transformer variants
61
  head_scale=False, # used in NormFormer
62
  ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
63
- ln_positions="deepnet", # layer normalization positions, "normformer", "swinv2", "deepnet" (same as post-ln)
64
  use_cosine_attention=False, # used in Swin v2
65
  tau_init=0.05, # used only in cosine attention (Swin v2)
66
  use_deepnet_scaling=False, # used in Deepnet
@@ -80,6 +80,7 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
80
  assert ln_positions in [
81
  "normformer",
82
  "swinv2",
 
83
  "deepnet",
84
  ], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
85
  self.ln_positions = ln_positions
 
60
  # transformer variants
61
  head_scale=False, # used in NormFormer
62
  ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
63
+ ln_positions="deepnet", # layer normalization positions, "normformer", "swinv2", "cogview", "deepnet" (same as post-ln)
64
  use_cosine_attention=False, # used in Swin v2
65
  tau_init=0.05, # used only in cosine attention (Swin v2)
66
  use_deepnet_scaling=False, # used in Deepnet
 
80
  assert ln_positions in [
81
  "normformer",
82
  "swinv2",
83
+ "cogview",
84
  "deepnet",
85
  ], "ln_positions must be 'normformer', 'swinv2', 'cogview' or 'deepnet'"
86
  self.ln_positions = ln_positions
src/dalle_mini/model/modeling.py CHANGED
@@ -373,7 +373,7 @@ class GLU(nn.Module):
373
  self.config
374
  )
375
 
376
- if self.config.ln_positions in ["normformer"]:
377
  x = norm(
378
  self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
379
  )(x)
@@ -411,7 +411,7 @@ class GLU(nn.Module):
411
  if self.config.use_deepnet_scaling
412
  else jax.nn.initializers.normal(self.config.init_std),
413
  )(x)
414
- if self.config.ln_positions in ["swinv2"]:
415
  x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
416
  x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
417
  return x
@@ -432,7 +432,7 @@ class FFN(nn.Module):
432
  gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
433
  self.config
434
  )
435
- if self.config.ln_positions in ["normformer"]:
436
  x = norm(
437
  self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
438
  )(x)
@@ -460,7 +460,7 @@ class FFN(nn.Module):
460
  if self.config.use_deepnet_scaling
461
  else jax.nn.initializers.normal(self.config.init_std),
462
  )(x)
463
- if self.config.ln_positions in ["swinv2"]:
464
  x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
465
  x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
466
  return x
@@ -593,7 +593,7 @@ class FlaxBartDecoderLayer(nn.Module):
593
  residual = hidden_states
594
 
595
  # Self Attention
596
- if self.config.ln_positions in ["normformer"]:
597
  hidden_states = norm(
598
  self.config.ln_type,
599
  dtype=self.dtype,
@@ -615,7 +615,7 @@ class FlaxBartDecoderLayer(nn.Module):
615
  init_cache=init_cache,
616
  )
617
 
618
- if self.config.ln_positions in ["normformer", "swinv2"]:
619
  hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
620
  hidden_states
621
  )
@@ -632,7 +632,7 @@ class FlaxBartDecoderLayer(nn.Module):
632
  cross_attn_weights = None
633
  if encoder_hidden_states is not None:
634
  residual = hidden_states
635
- if self.config.ln_positions in ["normformer"]:
636
  hidden_states = norm(
637
  self.config.ln_type,
638
  dtype=self.dtype,
@@ -652,7 +652,7 @@ class FlaxBartDecoderLayer(nn.Module):
652
  key_value_states=encoder_hidden_states,
653
  attention_mask=encoder_attention_mask,
654
  )
655
- if self.config.ln_positions in ["normformer", "swinv2"]:
656
  hidden_states = norm(
657
  self.config.ln_type, dtype=self.dtype, epsilon=1e-05
658
  )(hidden_states)
 
373
  self.config
374
  )
375
 
376
+ if self.config.ln_positions in ["normformer", "cogview"]:
377
  x = norm(
378
  self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
379
  )(x)
 
411
  if self.config.use_deepnet_scaling
412
  else jax.nn.initializers.normal(self.config.init_std),
413
  )(x)
414
+ if self.config.ln_positions in ["swinv2", "cogview"]:
415
  x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
416
  x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
417
  return x
 
432
  gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
433
  self.config
434
  )
435
+ if self.config.ln_positions in ["normformer", "cogview"]:
436
  x = norm(
437
  self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
438
  )(x)
 
460
  if self.config.use_deepnet_scaling
461
  else jax.nn.initializers.normal(self.config.init_std),
462
  )(x)
463
+ if self.config.ln_positions in ["swinv2", "cogview"]:
464
  x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
465
  x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
466
  return x
 
593
  residual = hidden_states
594
 
595
  # Self Attention
596
+ if self.config.ln_positions in ["normformer", "cogview"]:
597
  hidden_states = norm(
598
  self.config.ln_type,
599
  dtype=self.dtype,
 
615
  init_cache=init_cache,
616
  )
617
 
618
+ if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
619
  hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
620
  hidden_states
621
  )
 
632
  cross_attn_weights = None
633
  if encoder_hidden_states is not None:
634
  residual = hidden_states
635
+ if self.config.ln_positions in ["normformer", "cogview"]:
636
  hidden_states = norm(
637
  self.config.ln_type,
638
  dtype=self.dtype,
 
652
  key_value_states=encoder_hidden_states,
653
  attention_mask=encoder_attention_mask,
654
  )
655
+ if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
656
  hidden_states = norm(
657
  self.config.ln_type, dtype=self.dtype, epsilon=1e-05
658
  )(hidden_states)