feat: allow more configurations
src/dalle_mini/model/configuration.py
CHANGED
@@ -58,13 +58,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
         tie_word_embeddings=False,  # different modalities and sizes
         do_sample=True,
         # transformer variants
-        head_scale=False,  # used in NormFormer
         ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
-        ln_positions="…
+        ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "deepnet" (same as postln)
+        head_scale=True,  # used in NormFormer
         use_cosine_attention=False,  # used in Swin v2
         tau_init=0.05,  # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False,  # used in Deepnet
-        use_glu=…
+        use_glu=True,  # "GLU Variants Improve Transformer"
+        use_all_scale=True,  # use scale in layernorm even when seemingly unnecessary
         **kwargs,
     ):
         # text normalizer
@@ -83,11 +84,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
             "cogview",
             "deepnet",
         ], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
+        if ln_positions == "deepnet":
+            ln_positions = "postln"
         self.ln_positions = ln_positions
         self.use_cosine_attention = use_cosine_attention
         self.tau_init = tau_init
         self.use_deepnet_scaling = use_deepnet_scaling
         self.use_glu = use_glu
+        self.use_all_scale = use_all_scale
 
         # common parameters
         self.encoder_vocab_size = encoder_vocab_size
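The snippet below is a rough usage sketch of the new options: the keyword names come straight from the diff above, while the import path and the assumption that every other constructor argument has a workable default are mine, not guaranteed by this commit.

# Hypothetical usage of the new configuration flags (names from the diff above;
# import path and remaining defaults are assumed).
from dalle_mini.model.configuration import DalleBartConfig

config = DalleBartConfig(
    ln_type="layernorm",         # or "rmsnorm"
    ln_positions="deepnet",      # normalization placement; "deepnet" is remapped in __init__
    head_scale=True,             # NormFormer-style learned head scaling
    use_cosine_attention=False,  # Swin v2 cosine attention stays off
    tau_init=0.05,               # only relevant when cosine attention is on
    use_deepnet_scaling=False,
    use_glu=True,                # "GLU Variants Improve Transformer"
    use_all_scale=True,          # keep a scale parameter in every layer norm
)
print(config.ln_positions)       # "postln" -- "deepnet" is treated as post-LN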
src/dalle_mini/model/modeling.py
CHANGED
@@ -375,7 +375,10 @@ class GLU(nn.Module):
 
         if self.config.ln_positions in ["normformer", "cogview"]:
             x = norm(
-                self.config.ln_type, …
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         w = nn.Dense(
             self.ffn_dim,
@@ -397,7 +400,10 @@ class GLU(nn.Module):
         x = w * v
         if self.config.ln_positions in ["normformer"]:
             x = norm(
-                self.config.ln_type, …
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -434,7 +440,10 @@ class FFN(nn.Module):
         )
         if self.config.ln_positions in ["normformer", "cogview"]:
             x = norm(
-                self.config.ln_type, …
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         x = nn.Dense(
             self.ffn_dim,
@@ -447,7 +456,10 @@ class FFN(nn.Module):
         x = ACT2FN[self.config.activation_function](x)
         if self.config.ln_positions in ["normformer"]:
             x = norm(
-                self.config.ln_type, …
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
             )(x)
         x = nn.Dropout(rate=self.config.activation_dropout)(
             x, deterministic=deterministic
@@ -495,10 +507,13 @@ class FlaxBartEncoderLayer(nn.Module):
 
         embed_dim = self.config.d_model
         residual = hidden_states
-        if self.config.ln_positions in ["normformer"]:
-            hidden_states = norm(
-                …
-            …
+        if self.config.ln_positions in ["normformer", "cogview"]:
+            hidden_states = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.use_all_scale,
+            )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
             embed_dim=embed_dim,
@@ -509,7 +524,7 @@ class FlaxBartEncoderLayer(nn.Module):
             is_encoder=True,
         )(hidden_states=hidden_states, attention_mask=attention_mask)
 
-        if self.config.ln_positions in ["normformer", "swinv2"]:
+        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -517,7 +532,7 @@ class FlaxBartEncoderLayer(nn.Module):
             hidden_states, deterministic=deterministic
         )
         hidden_states = residual * res_gain + hidden_states
-        if self.config.ln_positions in ["…
+        if self.config.ln_positions in ["postln"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -542,8 +557,12 @@ class FlaxBartEncoderLayer(nn.Module):
         )
         hidden_states = ff_block(hidden_states, deterministic=deterministic)
         hidden_states = residual * res_gain + hidden_states
-        if self.add_norm or self.config.ln_positions in ["…
-            use_scale = …
+        if self.add_norm or self.config.ln_positions in ["postln"]:
+            use_scale = (
+                self.use_scale
+                or self.config.ln_positions == "postln"
+                or self.config.use_all_scale
+            )
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,
@@ -598,7 +617,7 @@ class FlaxBartDecoderLayer(nn.Module):
                 self.config.ln_type,
                 dtype=self.dtype,
                 epsilon=1e-05,
-                use_scale=…
+                use_scale=self.config.use_all_scale,
             )(hidden_states)
         hidden_states, attn_weights = FlaxBartAttention(
             config=self.config,
@@ -623,7 +642,7 @@ class FlaxBartDecoderLayer(nn.Module):
             hidden_states, deterministic=deterministic
         )
         hidden_states = residual * res_gain + hidden_states
-        if self.config.ln_positions in ["…
+        if self.config.ln_positions in ["postln"]:
             hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                 hidden_states
             )
@@ -637,7 +656,7 @@ class FlaxBartDecoderLayer(nn.Module):
                 self.config.ln_type,
                 dtype=self.dtype,
                 epsilon=1e-05,
-                use_scale=…
+                use_scale=self.config.use_all_scale,
             )(hidden_states)
         hidden_states, cross_attn_weights = FlaxBartAttention(
             config=self.config,
@@ -660,7 +679,7 @@ class FlaxBartDecoderLayer(nn.Module):
             hidden_states, deterministic=deterministic
        )
         hidden_states = residual * res_gain + hidden_states
-        if self.config.ln_positions in ["…
+        if self.config.ln_positions in ["postln"]:
             hidden_states = norm(
                 self.config.ln_type, dtype=self.dtype, epsilon=1e-05
             )(hidden_states)
@@ -686,8 +705,12 @@ class FlaxBartDecoderLayer(nn.Module):
         )
         hidden_states = ff_block(hidden_states, deterministic=deterministic)
         hidden_states = residual * res_gain + hidden_states
-        if self.add_norm or self.config.ln_positions in ["…
-            use_scale = …
+        if self.add_norm or self.config.ln_positions in ["postln"]:
+            use_scale = (
+                self.use_scale
+                or self.config.ln_positions == "postln"
+                or self.config.use_all_scale
+            )
             hidden_states = norm(
                 self.config.ln_type,
                 dtype=self.dtype,
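For orientation, here is a minimal sketch of a norm() factory that would be compatible with the call sites added above. It is an assumption about the helper's shape, not the implementation actually in modeling.py; in particular the "rmsnorm" branch is only a stand-in.

# Sketch only: a factory matching calls of the form
# norm(ln_type, dtype=..., epsilon=..., use_scale=...)(x) seen in the diff.
import flax.linen as nn
import jax.numpy as jnp

def norm(ln_type, *, dtype=jnp.float32, epsilon=1e-05, use_scale=True):
    if ln_type == "rmsnorm":
        # Stand-in: a real RMSNorm skips mean-centering; we approximate it here
        # with a bias-free LayerNorm so the sketch stays self-contained.
        return nn.LayerNorm(
            epsilon=epsilon, dtype=dtype, use_bias=False, use_scale=use_scale
        )
    return nn.LayerNorm(
        epsilon=epsilon, dtype=dtype, use_bias=True, use_scale=use_scale
    )

Under this reading, use_all_scale=True keeps the learned scale on the extra pre-attention and mid-block norms, while the post-LN path ("postln", or "deepnet" remapped to it) forces use_scale on for the final layer norms regardless.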