Add an explicit cast of `attn_weight` to `v.dtype` before the matmul with `v`, where running without autocast causes issues
Browse files
- attention.py +1 -1
attention.py
CHANGED
@@ -55,7 +55,7 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_
|
|
55 |
attn_weight = torch.softmax(attn_weight, dim=-1)
|
56 |
if dropout_p:
|
57 |
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
|
58 |
-
out = attn_weight.matmul(v)
|
59 |
out = rearrange(out, 'b h s d -> b s (h d)')
|
60 |
if needs_weights:
|
61 |
return (out, attn_weight, past_key_value)
|
|
|
55 |
attn_weight = torch.softmax(attn_weight, dim=-1)
|
56 |
if dropout_p:
|
57 |
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
|
58 |
+
out = attn_weight.to(v.dtype).matmul(v)
|
59 |
out = rearrange(out, 'b h s d -> b s (h d)')
|
60 |
if needs_weights:
|
61 |
return (out, attn_weight, past_key_value)
|