fix: sinkformer gradient
src/dalle_mini/model/modeling.py
CHANGED
@@ -215,8 +215,25 @@ def dot_product_attention_weights(
     # normalize the attention weights
     attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
     for i in range(sinkhorn_iters - 1):
+        # TODO: this is unstable, requires lse space
         axis = -2 if i % 2 == 0 else -1
-
+        if mask is not None:
+            attn_weights = jnp.where(
+                mask > 0,
+                attn_weights
+                / (
+                    1e-5
+                    + jax.lax.stop_gradient(
+                        jnp.sum(attn_weights, axis=axis, where=mask, keepdims=True)
+                    )
+                ),
+                0.0,
+            )
+        else:
+            attn_weights = attn_weights / (
+                1e-5
+                + jax.lax.stop_gradient(jnp.sum(attn_weights, axis=axis, keepdims=True))
+            )

     # apply attention dropout
     if not deterministic and dropout_rate > 0.0:
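
The hunk above replaces the previous one-line re-normalization inside the Sinkhorn loop of dot_product_attention_weights: each alternating row/column sum is now wrapped in jax.lax.stop_gradient, and when a mask is provided the sum only covers unmasked positions while masked entries are forced to 0. A minimal, self-contained sketch of that step follows; the function name sinkhorn_attention_weights, its signature, and the example inputs are illustrative only, not the project's API.

import jax
import jax.numpy as jnp


def sinkhorn_attention_weights(logits, mask=None, sinkhorn_iters=2):
    """Softmax followed by alternating row/column re-normalization (sketch).

    logits has shape [..., q_length, kv_length]; mask is an optional boolean
    array broadcastable to that shape. Each division uses a denominator wrapped
    in jax.lax.stop_gradient, so backprop only flows through the numerator,
    which is the change this commit makes.
    """
    attn_weights = jax.nn.softmax(logits)
    for i in range(sinkhorn_iters - 1):
        # alternate between normalizing over queries (-2) and keys (-1)
        axis = -2 if i % 2 == 0 else -1
        if mask is not None:
            denom = jax.lax.stop_gradient(
                jnp.sum(attn_weights, axis=axis, where=mask, keepdims=True)
            )
            attn_weights = jnp.where(mask > 0, attn_weights / (1e-5 + denom), 0.0)
        else:
            denom = jax.lax.stop_gradient(
                jnp.sum(attn_weights, axis=axis, keepdims=True)
            )
            attn_weights = attn_weights / (1e-5 + denom)
    return attn_weights


# Example: two Sinkhorn iterations over random logits with a causal-style mask.
logits = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8, 8))
mask = jnp.tril(jnp.ones((8, 8), dtype=bool))
weights = sinkhorn_attention_weights(logits, mask=mask, sinkhorn_iters=2)
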
@@ -396,6 +413,7 @@ class FlaxBartAttention(FlaxBartAttention):
             query_states,
             key_states,
             bias=attention_bias,
+            mask=attention_mask,
             dropout_rng=dropout_rng,
             dropout_rate=self.dropout,
             broadcast_dropout=True,
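
The TODO added in the first hunk notes that repeating this division in probability space is numerically unstable and that it should really be done in log-sum-exp space. Purely as an illustration of that remark (this is not part of the commit or the repository, and it omits the stop_gradient trick), the same alternating normalization in log space could look like:

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_attention_weights_logspace(logits, mask=None, sinkhorn_iters=2):
    """Hypothetical log-space ("lse") variant of the normalization above.

    Masked positions are pushed to -inf once, every re-normalization becomes a
    numerically stable subtraction of logsumexp, and probabilities are
    recovered with a single exp at the end. Assumes every row/column keeps at
    least one unmasked position.
    """
    if mask is not None:
        logits = jnp.where(mask > 0, logits, -jnp.inf)
    log_weights = jax.nn.log_softmax(logits)
    for i in range(sinkhorn_iters - 1):
        axis = -2 if i % 2 == 0 else -1
        log_weights = log_weights - logsumexp(log_weights, axis=axis, keepdims=True)
    return jnp.exp(log_weights)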
|