Spaces:
Running
Running
Shuming Ma
Shuming Ma
commited on
Commit
•
503d6b4
1
Parent(s):
02824a7
fix: DeepNet doesn't scale weights of embedding/output layers (#150)
Browse files
src/dalle_mini/model/modeling.py
CHANGED
@@ -883,9 +883,7 @@ class FlaxBartEncoder(FlaxBartEncoder):
|
|
883 |
self.embed_positions = nn.Embed(
|
884 |
self.config.max_text_length + self.offset,
|
885 |
embed_dim,
|
886 |
-
embedding_init=
|
887 |
-
if self.config.use_deepnet_scaling
|
888 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
889 |
)
|
890 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
891 |
self.layernorm_embedding = norm(
|
@@ -917,9 +915,7 @@ class FlaxBartDecoder(FlaxBartDecoder):
|
|
917 |
self.embed_positions = nn.Embed(
|
918 |
self.config.image_length + self.offset, # image length for BOS
|
919 |
embed_dim,
|
920 |
-
embedding_init=
|
921 |
-
if self.config.use_deepnet_scaling
|
922 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
923 |
)
|
924 |
|
925 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
@@ -939,16 +935,12 @@ class FlaxBartModule(FlaxBartModule):
|
|
939 |
encoder_embed_tokens = nn.Embed(
|
940 |
self.config.encoder_vocab_size,
|
941 |
self.config.d_model,
|
942 |
-
embedding_init=
|
943 |
-
if self.config.use_deepnet_scaling
|
944 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
945 |
)
|
946 |
decoder_embed_tokens = nn.Embed(
|
947 |
self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
|
948 |
self.config.d_model,
|
949 |
-
embedding_init=
|
950 |
-
if self.config.use_deepnet_scaling
|
951 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
952 |
)
|
953 |
|
954 |
self.encoder = FlaxBartEncoder(
|
@@ -1288,9 +1280,7 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
|
|
1288 |
+ 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
|
1289 |
use_bias=False,
|
1290 |
dtype=self.dtype,
|
1291 |
-
kernel_init=
|
1292 |
-
if self.config.use_deepnet_scaling
|
1293 |
-
else jax.nn.initializers.normal(self.config.init_std),
|
1294 |
)
|
1295 |
|
1296 |
def __call__(
|
|
|
883 |
self.embed_positions = nn.Embed(
|
884 |
self.config.max_text_length + self.offset,
|
885 |
embed_dim,
|
886 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
887 |
)
|
888 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
889 |
self.layernorm_embedding = norm(
|
|
|
915 |
self.embed_positions = nn.Embed(
|
916 |
self.config.image_length + self.offset, # image length for BOS
|
917 |
embed_dim,
|
918 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
919 |
)
|
920 |
|
921 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
|
|
935 |
encoder_embed_tokens = nn.Embed(
|
936 |
self.config.encoder_vocab_size,
|
937 |
self.config.d_model,
|
938 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
939 |
)
|
940 |
decoder_embed_tokens = nn.Embed(
|
941 |
self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
|
942 |
self.config.d_model,
|
943 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
944 |
)
|
945 |
|
946 |
self.encoder = FlaxBartEncoder(
|
|
|
1280 |
+ 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
|
1281 |
use_bias=False,
|
1282 |
dtype=self.dtype,
|
1283 |
+
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
|
|
|
|
1284 |
)
|
1285 |
|
1286 |
def __call__(
|