Spaces: Running on Zero
asigalov61 committed
Update x_transformer.py
x_transformer.py +2 -1
x_transformer.py CHANGED
@@ -27,6 +27,7 @@ from functools import partial
 import torch
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from collections import namedtuple
 from functools import wraps
@@ -206,7 +207,7 @@ class Attend(nn.Module):
 
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 
-        with
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
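For context on what the change does: recent PyTorch releases expose torch.nn.attention.sdpa_kernel (superseding the older torch.backends.cuda.sdp_kernel context manager), which takes a list of SDPBackend values and restricts F.scaled_dot_product_attention to those implementations inside the with-block. A minimal standalone sketch of the pattern the commit adopts; tensor shapes are illustrative and not taken from the Space:

# Minimal sketch (not part of the commit): restricting scaled_dot_product_attention
# to specific backends with the torch.nn.attention API used in the diff.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, heads, sequence length, head dimension).
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Only the listed backends may be selected inside the context manager;
# the math backend also runs on CPU, so no GPU is required for this sketch.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])

Listing only MATH and EFFICIENT_ATTENTION presumably excludes the flash-attention kernel on purpose; it could be re-enabled by adding SDPBackend.FLASH_ATTENTION to the list.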