Spaces:
Running
Running
feat: allow relative position (#156)
Browse files
src/dalle_mini/model/configuration.py
CHANGED
@@ -64,12 +64,14 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
64 |
use_head_scale=False, # used in NormFormer
|
65 |
use_cosine_attention=False, # used in Swin v2
|
66 |
tau_init=0.05, # used only in cosine attention (Swin v2)
|
|
|
|
|
67 |
use_deepnet_scaling=False, # used in Deepnet
|
68 |
use_glu=False, # "GLU Variants Improve Transformer"
|
69 |
use_alibi=False, # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
|
70 |
sinkhorn_iters=1, # used in SinkFormers
|
71 |
-
use_final_ln_encoder=
|
72 |
-
use_final_ln_decoder=
|
73 |
# parameters that should not be necessary but could affect results
|
74 |
force_ln_scale=False, # force scale in layernorm even when followed by dense layers
|
75 |
**kwargs,
|
@@ -98,6 +100,8 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
98 |
self.ln_positions = ln_positions
|
99 |
self.use_cosine_attention = use_cosine_attention
|
100 |
self.tau_init = tau_init
|
|
|
|
|
101 |
self.use_deepnet_scaling = use_deepnet_scaling
|
102 |
self.use_glu = use_glu
|
103 |
self.use_alibi = use_alibi
|
|
|
64 |
use_head_scale=False, # used in NormFormer
|
65 |
use_cosine_attention=False, # used in Swin v2
|
66 |
tau_init=0.05, # used only in cosine attention (Swin v2)
|
67 |
+
use_absolute_position_embeddings=True, # default
|
68 |
+
use_swin_position_embeddings=False, # used in Swin v1/v2
|
69 |
use_deepnet_scaling=False, # used in Deepnet
|
70 |
use_glu=False, # "GLU Variants Improve Transformer"
|
71 |
use_alibi=False, # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
|
72 |
sinkhorn_iters=1, # used in SinkFormers
|
73 |
+
use_final_ln_encoder=True, # final layer normalization in encoder
|
74 |
+
use_final_ln_decoder=True, # final layer normalization in decoder
|
75 |
# parameters that should not be necessary but could affect results
|
76 |
force_ln_scale=False, # force scale in layernorm even when followed by dense layers
|
77 |
**kwargs,
|
|
|
100 |
self.ln_positions = ln_positions
|
101 |
self.use_cosine_attention = use_cosine_attention
|
102 |
self.tau_init = tau_init
|
103 |
+
self.use_absolute_position_embeddings = use_absolute_position_embeddings
|
104 |
+
self.use_swin_position_embeddings = use_swin_position_embeddings
|
105 |
self.use_deepnet_scaling = use_deepnet_scaling
|
106 |
self.use_glu = use_glu
|
107 |
self.use_alibi = use_alibi
|
src/dalle_mini/model/modeling.py
CHANGED
@@ -25,6 +25,7 @@ import flax.linen as nn
|
|
25 |
import jax
|
26 |
import jax.numpy as jnp
|
27 |
import msgpack.exceptions
|
|
|
28 |
from flax.core.frozen_dict import unfreeze
|
29 |
from flax.linen import combine_masks, make_causal_mask
|
30 |
from flax.linen import partitioning as nn_partitioning
|
@@ -52,8 +53,6 @@ from transformers.modeling_flax_outputs import (
|
|
52 |
from transformers.modeling_flax_utils import ACT2FN
|
53 |
from transformers.models.bart.modeling_flax_bart import (
|
54 |
FlaxBartAttention,
|
55 |
-
FlaxBartDecoder,
|
56 |
-
FlaxBartEncoder,
|
57 |
FlaxBartForConditionalGeneration,
|
58 |
FlaxBartForConditionalGenerationModule,
|
59 |
FlaxBartModule,
|
@@ -180,6 +179,7 @@ def dot_product_attention_weights(
|
|
180 |
key: Any,
|
181 |
bias: Optional[Any] = None,
|
182 |
mask: Optional[Any] = None,
|
|
|
183 |
broadcast_dropout: bool = True,
|
184 |
dropout_rng: Optional[PRNGKey] = None,
|
185 |
dropout_rate: float = 0.0,
|
@@ -210,6 +210,10 @@ def dot_product_attention_weights(
|
|
210 |
if bias is not None:
|
211 |
attn_weights = attn_weights + bias
|
212 |
|
|
|
|
|
|
|
|
|
213 |
# normalize the attention weights
|
214 |
if causal or sinkhorn_iters == 1:
|
215 |
# sinkhorn does not work for causal (leaks info of future tokens into past)
|
@@ -251,6 +255,8 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
251 |
"""
|
252 |
|
253 |
is_encoder: bool = False
|
|
|
|
|
254 |
|
255 |
def setup(self) -> None:
|
256 |
self.head_dim = self.embed_dim // self.num_heads
|
@@ -305,6 +311,15 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
305 |
(1, self.num_heads, 1, 1),
|
306 |
)
|
307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
if self.causal:
|
309 |
# used only in decoder
|
310 |
self.causal_mask = make_causal_mask(
|
@@ -400,11 +415,21 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
400 |
key_states = key_states / (
|
401 |
jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
|
402 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
attn_weights = dot_product_attention_weights(
|
404 |
query_states,
|
405 |
key_states,
|
406 |
bias=attention_bias,
|
407 |
mask=attention_mask,
|
|
|
408 |
dropout_rng=dropout_rng,
|
409 |
dropout_rate=self.dropout,
|
410 |
broadcast_dropout=True,
|
@@ -593,6 +618,8 @@ class FlaxBartEncoderLayer(nn.Module):
|
|
593 |
bias=self.config.use_bias,
|
594 |
dtype=self.dtype,
|
595 |
is_encoder=True,
|
|
|
|
|
596 |
)(hidden_states=hidden_states, attention_mask=attention_mask)
|
597 |
|
598 |
if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
|
@@ -699,6 +726,8 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
699 |
bias=self.config.use_bias,
|
700 |
dtype=self.dtype,
|
701 |
is_encoder=False,
|
|
|
|
|
702 |
)(
|
703 |
hidden_states=hidden_states,
|
704 |
attention_mask=attention_mask,
|
@@ -737,6 +766,8 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
737 |
bias=self.config.use_bias,
|
738 |
dtype=self.dtype,
|
739 |
is_encoder=False,
|
|
|
|
|
740 |
)(
|
741 |
hidden_states=hidden_states,
|
742 |
key_value_states=encoder_hidden_states,
|
@@ -953,7 +984,10 @@ class FlaxBartDecoderLayerCollection(nn.Module):
|
|
953 |
)
|
954 |
|
955 |
|
956 |
-
class FlaxBartEncoder(
|
|
|
|
|
|
|
957 |
"""
|
958 |
Edits:
|
959 |
- offset set to 0 (no padding token)
|
@@ -972,18 +1006,62 @@ class FlaxBartEncoder(FlaxBartEncoder):
|
|
972 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
973 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
974 |
self.offset = 0
|
975 |
-
self.
|
976 |
-
self.
|
977 |
-
|
978 |
-
|
979 |
-
|
|
|
980 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
981 |
self.layernorm_embedding = norm(
|
982 |
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
983 |
)
|
984 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
985 |
|
986 |
-
class FlaxBartDecoder(
|
|
|
|
|
|
|
987 |
"""
|
988 |
Edits:
|
989 |
- offset set to 0 (no padding token)
|
@@ -1004,17 +1082,65 @@ class FlaxBartDecoder(FlaxBartDecoder):
|
|
1004 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
1005 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
1006 |
self.offset = 0
|
1007 |
-
self.
|
1008 |
-
self.
|
1009 |
-
|
1010 |
-
|
1011 |
-
|
|
|
1012 |
|
1013 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
1014 |
self.layernorm_embedding = norm(
|
1015 |
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
1016 |
)
|
1017 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1018 |
|
1019 |
class FlaxBartModule(FlaxBartModule):
|
1020 |
"""
|
|
|
25 |
import jax
|
26 |
import jax.numpy as jnp
|
27 |
import msgpack.exceptions
|
28 |
+
from einops import rearrange
|
29 |
from flax.core.frozen_dict import unfreeze
|
30 |
from flax.linen import combine_masks, make_causal_mask
|
31 |
from flax.linen import partitioning as nn_partitioning
|
|
|
53 |
from transformers.modeling_flax_utils import ACT2FN
|
54 |
from transformers.models.bart.modeling_flax_bart import (
|
55 |
FlaxBartAttention,
|
|
|
|
|
56 |
FlaxBartForConditionalGeneration,
|
57 |
FlaxBartForConditionalGenerationModule,
|
58 |
FlaxBartModule,
|
|
|
179 |
key: Any,
|
180 |
bias: Optional[Any] = None,
|
181 |
mask: Optional[Any] = None,
|
182 |
+
embed_pos: Optional[Any] = None,
|
183 |
broadcast_dropout: bool = True,
|
184 |
dropout_rng: Optional[PRNGKey] = None,
|
185 |
dropout_rate: float = 0.0,
|
|
|
210 |
if bias is not None:
|
211 |
attn_weights = attn_weights + bias
|
212 |
|
213 |
+
# add relative position
|
214 |
+
if embed_pos is not None:
|
215 |
+
attn_weights = attn_weights + embed_pos
|
216 |
+
|
217 |
# normalize the attention weights
|
218 |
if causal or sinkhorn_iters == 1:
|
219 |
# sinkhorn does not work for causal (leaks info of future tokens into past)
|
|
|
255 |
"""
|
256 |
|
257 |
is_encoder: bool = False
|
258 |
+
q_length: int = None
|
259 |
+
k_length: int = None
|
260 |
|
261 |
def setup(self) -> None:
|
262 |
self.head_dim = self.embed_dim // self.num_heads
|
|
|
311 |
(1, self.num_heads, 1, 1),
|
312 |
)
|
313 |
|
314 |
+
if self.config.use_swin_position_embeddings:
|
315 |
+
self.rel_bias = nn.Embed(
|
316 |
+
self.q_length,
|
317 |
+
self.k_length * self.num_heads,
|
318 |
+
embedding_init=deepnet_init()
|
319 |
+
if self.config.use_deepnet_scaling
|
320 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
321 |
+
)
|
322 |
+
|
323 |
if self.causal:
|
324 |
# used only in decoder
|
325 |
self.causal_mask = make_causal_mask(
|
|
|
415 |
key_states = key_states / (
|
416 |
jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
|
417 |
)
|
418 |
+
|
419 |
+
# relative position embeddings
|
420 |
+
if self.config.use_swin_position_embeddings:
|
421 |
+
position_ids = jnp.arange(self.q_length)
|
422 |
+
embed_pos = self.rel_bias(position_ids)
|
423 |
+
embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
|
424 |
+
else:
|
425 |
+
embed_pos = None
|
426 |
+
|
427 |
attn_weights = dot_product_attention_weights(
|
428 |
query_states,
|
429 |
key_states,
|
430 |
bias=attention_bias,
|
431 |
mask=attention_mask,
|
432 |
+
embed_pos=embed_pos,
|
433 |
dropout_rng=dropout_rng,
|
434 |
dropout_rate=self.dropout,
|
435 |
broadcast_dropout=True,
|
|
|
618 |
bias=self.config.use_bias,
|
619 |
dtype=self.dtype,
|
620 |
is_encoder=True,
|
621 |
+
q_length=self.config.max_text_length,
|
622 |
+
k_length=self.config.max_text_length,
|
623 |
)(hidden_states=hidden_states, attention_mask=attention_mask)
|
624 |
|
625 |
if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
|
|
|
726 |
bias=self.config.use_bias,
|
727 |
dtype=self.dtype,
|
728 |
is_encoder=False,
|
729 |
+
q_length=self.config.image_length,
|
730 |
+
k_length=self.config.image_length,
|
731 |
)(
|
732 |
hidden_states=hidden_states,
|
733 |
attention_mask=attention_mask,
|
|
|
766 |
bias=self.config.use_bias,
|
767 |
dtype=self.dtype,
|
768 |
is_encoder=False,
|
769 |
+
q_length=self.config.image_length,
|
770 |
+
k_length=self.config.max_text_length,
|
771 |
)(
|
772 |
hidden_states=hidden_states,
|
773 |
key_value_states=encoder_hidden_states,
|
|
|
984 |
)
|
985 |
|
986 |
|
987 |
+
class FlaxBartEncoder(nn.Module):
|
988 |
+
config: DalleBartConfig
|
989 |
+
embed_tokens: nn.Embed
|
990 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
991 |
"""
|
992 |
Edits:
|
993 |
- offset set to 0 (no padding token)
|
|
|
1006 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
1007 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
1008 |
self.offset = 0
|
1009 |
+
if self.config.use_absolute_position_embeddings:
|
1010 |
+
self.embed_positions = nn.Embed(
|
1011 |
+
self.config.max_text_length + self.offset, # image length for BOS
|
1012 |
+
embed_dim,
|
1013 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
1014 |
+
)
|
1015 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
1016 |
self.layernorm_embedding = norm(
|
1017 |
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
1018 |
)
|
1019 |
|
1020 |
+
def __call__(
|
1021 |
+
self,
|
1022 |
+
input_ids,
|
1023 |
+
attention_mask,
|
1024 |
+
position_ids,
|
1025 |
+
output_attentions: bool = False,
|
1026 |
+
output_hidden_states: bool = False,
|
1027 |
+
return_dict: bool = True,
|
1028 |
+
deterministic: bool = True,
|
1029 |
+
):
|
1030 |
+
input_shape = input_ids.shape
|
1031 |
+
input_ids = input_ids.reshape(-1, input_shape[-1])
|
1032 |
+
|
1033 |
+
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
|
1034 |
+
|
1035 |
+
if self.config.use_absolute_position_embeddings:
|
1036 |
+
embed_pos = self.embed_positions(position_ids + self.offset)
|
1037 |
+
hidden_states = hidden_states + embed_pos
|
1038 |
+
|
1039 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
1040 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
1041 |
+
|
1042 |
+
outputs = self.layers(
|
1043 |
+
hidden_states,
|
1044 |
+
attention_mask,
|
1045 |
+
deterministic=deterministic,
|
1046 |
+
output_attentions=output_attentions,
|
1047 |
+
output_hidden_states=output_hidden_states,
|
1048 |
+
return_dict=return_dict,
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
if not return_dict:
|
1052 |
+
return outputs
|
1053 |
+
|
1054 |
+
return FlaxBaseModelOutput(
|
1055 |
+
last_hidden_state=outputs.last_hidden_state,
|
1056 |
+
hidden_states=outputs.hidden_states,
|
1057 |
+
attentions=outputs.attentions,
|
1058 |
+
)
|
1059 |
+
|
1060 |
|
1061 |
+
class FlaxBartDecoder(nn.Module):
|
1062 |
+
config: DalleBartConfig
|
1063 |
+
embed_tokens: nn.Embed
|
1064 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
1065 |
"""
|
1066 |
Edits:
|
1067 |
- offset set to 0 (no padding token)
|
|
|
1082 |
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
1083 |
# and adjust num_embeddings appropriately. Other models don't have this hack
|
1084 |
self.offset = 0
|
1085 |
+
if self.config.use_absolute_position_embeddings:
|
1086 |
+
self.embed_positions = nn.Embed(
|
1087 |
+
self.config.image_length + self.offset, # image length for BOS
|
1088 |
+
embed_dim,
|
1089 |
+
embedding_init=jax.nn.initializers.normal(self.config.init_std),
|
1090 |
+
)
|
1091 |
|
1092 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
1093 |
self.layernorm_embedding = norm(
|
1094 |
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
1095 |
)
|
1096 |
|
1097 |
+
def __call__(
|
1098 |
+
self,
|
1099 |
+
input_ids,
|
1100 |
+
attention_mask,
|
1101 |
+
position_ids,
|
1102 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
1103 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
1104 |
+
init_cache: bool = False,
|
1105 |
+
output_attentions: bool = False,
|
1106 |
+
output_hidden_states: bool = False,
|
1107 |
+
return_dict: bool = True,
|
1108 |
+
deterministic: bool = True,
|
1109 |
+
):
|
1110 |
+
input_shape = input_ids.shape
|
1111 |
+
input_ids = input_ids.reshape(-1, input_shape[-1])
|
1112 |
+
|
1113 |
+
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
|
1114 |
+
|
1115 |
+
if self.config.use_absolute_position_embeddings:
|
1116 |
+
embed_pos = self.embed_positions(position_ids + self.offset)
|
1117 |
+
hidden_states = hidden_states + embed_pos
|
1118 |
+
|
1119 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
1120 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
1121 |
+
|
1122 |
+
outputs = self.layers(
|
1123 |
+
hidden_states,
|
1124 |
+
attention_mask,
|
1125 |
+
encoder_hidden_states,
|
1126 |
+
encoder_attention_mask,
|
1127 |
+
deterministic=deterministic,
|
1128 |
+
init_cache=init_cache,
|
1129 |
+
output_attentions=output_attentions,
|
1130 |
+
output_hidden_states=output_hidden_states,
|
1131 |
+
return_dict=return_dict,
|
1132 |
+
)
|
1133 |
+
|
1134 |
+
if not return_dict:
|
1135 |
+
return outputs
|
1136 |
+
|
1137 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
1138 |
+
last_hidden_state=outputs.last_hidden_state,
|
1139 |
+
hidden_states=outputs.hidden_states,
|
1140 |
+
attentions=outputs.attentions,
|
1141 |
+
cross_attentions=outputs.cross_attentions,
|
1142 |
+
)
|
1143 |
+
|
1144 |
|
1145 |
class FlaxBartModule(FlaxBartModule):
|
1146 |
"""
|
src/dalle_mini/model/partitions.py
CHANGED
@@ -38,6 +38,7 @@ def _get_partition_rules():
|
|
38 |
# embeddings
|
39 |
(("embed_positions", "embedding"), P("mp", None)),
|
40 |
(("embed_tokens", "embedding"), P("mp", None)),
|
|
|
41 |
# attention
|
42 |
(("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
43 |
(("out_proj", "kernel"), P("mp", None)),
|
|
|
38 |
# embeddings
|
39 |
(("embed_positions", "embedding"), P("mp", None)),
|
40 |
(("embed_tokens", "embedding"), P("mp", None)),
|
41 |
+
(("rel_bias", "embedding"), P(None, "mp")),
|
42 |
# attention
|
43 |
(("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
44 |
(("out_proj", "kernel"), P("mp", None)),
|