feat: sinkhorn in lse mode (#155)
src/dalle_mini/model/modeling.py (+18 -26)
@@ -187,9 +187,11 @@ def dot_product_attention_weights(
     dtype: Any = jnp.float32,
     precision: PrecisionLike = None,
     sinkhorn_iters: int = 1,
+    causal: bool = False,
 ):
     """
     Computes dot-product attention weights given query and key.
+    mask is included into the bias.

     Adapted from flax.linen.attention.dot_product_attention_weights"
     """

@@ -207,33 +209,22 @@ def dot_product_attention_weights(
     # apply attention bias: masking, dropout, proximity bias, etc.
     if bias is not None:
         attn_weights = attn_weights + bias
-    # apply attention mask
-    if mask is not None:
-        big_neg = jnp.finfo(dtype).min
-        attn_weights = jnp.where(mask, attn_weights, big_neg)

     # normalize the attention weights
-                    0.0,
-                )
-            else:
-                attn_weights = attn_weights / (
-                    1e-5
-                    + jax.lax.stop_gradient(jnp.sum(attn_weights, axis=axis, keepdims=True))
-                )
+    if causal or sinkhorn_iters == 1:
+        # sinkhorn does not work for causal (leaks info of future tokens into past)
+        attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+    else:
+        # adapted from https://github.com/lucidrains/sinkhorn-transformer
+        for i in range(sinkhorn_iters):
+            # when causal, some attn_weights have been set to -inf through bias
+            if i % 2 == 0:
+                attn_weights -= jax.nn.logsumexp(attn_weights, axis=-1, keepdims=True)
+            else:
+                attn_weights -= jax.nn.logsumexp(attn_weights, axis=-2, keepdims=True)
+            if mask is not None:
+                attn_weights = jnp.where(mask, attn_weights, -jnp.inf)
+        attn_weights = jnp.exp(attn_weights).astype(dtype)

     # apply attention dropout
     if not deterministic and dropout_rate > 0.0:
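The normalization above runs the Sinkhorn iterations entirely in log space: alternate iterations subtract `jax.nn.logsumexp` over the key axis and the query axis, so that `jnp.exp(attn_weights)` approaches a doubly stochastic matrix (rows and columns both sum to roughly 1). Below is a minimal standalone sketch, not the PR's code; the function name `sinkhorn_logspace` and the toy 4x4 check are made up for illustration.

```python
import jax
import jax.numpy as jnp


def sinkhorn_logspace(logits, mask=None, iters=3):
    """Alternately normalize over the key axis (-1) and the query axis (-2) in log space.

    After enough iterations, exp(logits) is close to doubly stochastic.
    Masked positions are pinned to -inf, and exp(-inf) == 0, so they
    never receive attention mass.
    """
    for i in range(iters):
        axis = -1 if i % 2 == 0 else -2  # key axis, then query axis
        logits = logits - jax.nn.logsumexp(logits, axis=axis, keepdims=True)
        if mask is not None:
            logits = jnp.where(mask, logits, -jnp.inf)
    return jnp.exp(logits)


# toy check on a random 4x4 block of attention scores
scores = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
weights = sinkhorn_logspace(scores, iters=20)
print(weights.sum(axis=-1))  # row sums    -> approximately [1. 1. 1. 1.]
print(weights.sum(axis=-2))  # column sums -> approximately [1. 1. 1. 1.]
```

With `sinkhorn_iters == 1` the loop reduces to a single normalization over the key axis followed by `exp`, i.e. an ordinary softmax, which is why the diff short-circuits that case (and the causal case) straight to `jax.nn.softmax`.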
@@ -392,7 +383,7 @@ class FlaxBartAttention(FlaxBartAttention):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+                jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
             )
         else:
             attention_bias = None
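For context, a small sketch of what the changed `lax.select` produces, using a made-up 1x4 mask (only `attention_mask`, `attention_bias` and the `lax.select` construction come from the diff; the toy values and the explicit `jnp.float32` in place of `self.dtype` are illustrative): attended positions get an additive bias of 0.0, masked positions get `-inf`, so after `attn_weights + bias` they exponentiate to exactly zero.

```python
import jax
import jax.numpy as jnp
from jax import lax

# made-up 1x4 mask: attend to the first three positions, mask the last one
attention_mask = jnp.array([[1, 1, 1, 0]])

# same construction as the hunk above: 0.0 where attending, -inf where masked
attention_bias = lax.select(
    attention_mask > 0,
    jnp.full(attention_mask.shape, 0.0).astype(jnp.float32),
    jnp.full(attention_mask.shape, -jnp.inf).astype(jnp.float32),
)
print(attention_bias)  # [[  0.   0.   0. -inf]]

# adding the bias drives masked logits to -inf, which exponentiates to exactly 0
scores = jnp.ones((1, 4))
print(jax.nn.softmax(scores + attention_bias))  # [[0.3333 0.3333 0.3333 0.    ]]
```

One plausible reason for moving from `jnp.finfo(self.dtype).min` to `-jnp.inf` is that `-inf` is a fixed point of the log-space updates: subtracting a finite `logsumexp` leaves it at `-inf`, and `jnp.exp(-inf)` is exactly `0`, so masked positions never re-enter the normalization.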
@@ -421,6 +412,7 @@ class FlaxBartAttention(FlaxBartAttention):
             dtype=self.dtype,
             precision=None,
             sinkhorn_iters=self.config.sinkhorn_iters,
+            causal=self.causal,
         )
         if self.config.use_cosine_attention:
             # divide by tau
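Why `causal` has to force the softmax path: the column normalization (the `axis=-2` half-step) couples each query row to the rows below it, i.e. to future positions. A hedged toy demonstration (the 3x3 scores are made up) showing that changing only the last, future row alters row 0 after one column step, while plain softmax leaves it untouched:

```python
import jax
import jax.numpy as jnp


def column_step(logits):
    # one Sinkhorn half-step over the query axis (-2), as in the odd iterations of the loop
    return logits - jax.nn.logsumexp(logits, axis=-2, keepdims=True)


# causal 3x3 scores: -inf above the diagonal (future keys are masked out)
base = jnp.array([[0.0, -jnp.inf, -jnp.inf],
                  [1.0, 2.0, -jnp.inf],
                  [1.0, 1.0, 3.0]])
# perturb only the *last* query row, a future position relative to row 0
changed = base.at[2, 0].set(5.0)

# softmax: row 0 is identical in both cases (no leak)
print(jax.nn.softmax(base, axis=-1)[0], jax.nn.softmax(changed, axis=-1)[0])
# column normalization: row 0 now differs, so it has seen the future row
print(jnp.exp(column_step(base))[0], jnp.exp(column_step(changed))[0])
```

Passing `causal=self.causal` through from `FlaxBartAttention` is what lets `dot_product_attention_weights` fall back to softmax for causal (decoder self-) attention while keeping the Sinkhorn normalization for the non-causal case.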