fix: sinkformer
src/dalle_mini/model/modeling.py
CHANGED
@@ -211,7 +211,7 @@ def dot_product_attention_weights(
     dtype: Any = jnp.float32,
     precision: PrecisionLike = None,
     sinkhorn_iters: int = 1,
-
+    is_encoder: bool = False,
 ):
     """
     Computes dot-product attention weights given query and key.
@@ -239,7 +239,7 @@ def dot_product_attention_weights(
     attn_weights = attn_weights + embed_pos

     # normalize the attention weights
-    if
+    if not is_encoder or sinkhorn_iters == 1:
         # sinkhorn does not work for causal (leaks info of future tokens into past)
         attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
     else:
@@ -461,7 +461,7 @@ class FlaxBartAttention(FlaxBartAttention):
                 dtype=self.dtype,
                 precision=None,
                 sinkhorn_iters=self.config.sinkhorn_iters,
-
+                is_encoder=self.is_encoder,
             )
             if self.config.use_cosine_attention:
                 # divide by tau
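The hunks above only show the softmax branch; the Sinkhorn else branch itself lies outside the diff context. As a rough orientation for what the new is_encoder gate protects, below is a minimal, self-contained sketch of Sinkformer-style attention normalization, assuming a log-space alternation of key-axis and query-axis softmaxes; the helper names (sinkhorn_attention_weights, normalize_attention) and the exact iteration scheme are illustrative assumptions, not the repository's actual code.

# Hypothetical sketch (not the repository's else branch): Sinkformer-style
# normalization of attention logits via alternating log-space softmaxes.
import jax
import jax.numpy as jnp


def sinkhorn_attention_weights(attn_logits, n_iters):
    # Alternate normalization over the key axis (-1) and the query axis (-2)
    # so the weights approach a doubly-stochastic matrix. With n_iters == 1
    # this reduces to a single softmax over keys.
    log_w = attn_logits
    for i in range(n_iters):
        log_w = jax.nn.log_softmax(log_w, axis=-1)      # rows sum to 1
        if i < n_iters - 1:
            log_w = jax.nn.log_softmax(log_w, axis=-2)  # columns sum to 1
    return jnp.exp(log_w)


def normalize_attention(attn_logits, sinkhorn_iters=1, is_encoder=False, dtype=jnp.float32):
    # Same gate as the patch: only non-causal (encoder) attention may run
    # more than one Sinkhorn iteration.
    if not is_encoder or sinkhorn_iters == 1:
        # sinkhorn does not work for causal (leaks info of future tokens into past)
        return jax.nn.softmax(attn_logits).astype(dtype)
    return sinkhorn_attention_weights(attn_logits, sinkhorn_iters).astype(dtype)

Gating on is_encoder matters because the query-axis normalization mixes weight mass across positions; under a causal mask this would let earlier positions absorb information about later tokens, which is exactly what the inline comment in the patch warns against.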