Spaces:
Running
Running
fix layernorm
Browse files
dalle_mini/modeling_bart_flax.py
CHANGED
@@ -329,7 +329,7 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
329 |
dropout=self.config.attention_dropout,
|
330 |
dtype=self.dtype,
|
331 |
)
|
332 |
-
self.encoder_attn_layer_norm = nn
|
333 |
self.fc1 = nn.Dense(
|
334 |
self.config.encoder_ffn_dim,
|
335 |
dtype=self.dtype,
|
|
|
329 |
dropout=self.config.attention_dropout,
|
330 |
dtype=self.dtype,
|
331 |
)
|
332 |
+
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
|
333 |
self.fc1 = nn.Dense(
|
334 |
self.config.encoder_ffn_dim,
|
335 |
dtype=self.dtype,
|