asigalov61 committed
Commit 49123fc · verified · 1 Parent(s): 7583533

Update x_transformer.py

Files changed (1)
  1. x_transformer.py +2 -1
x_transformer.py CHANGED
@@ -27,6 +27,7 @@ from functools import partial
 import torch
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from collections import namedtuple
 from functools import wraps
@@ -206,7 +207,7 @@ class Attend(nn.Module):
 
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 
-        with torch.backends.cuda.sdp_kernel(**config._asdict()):
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
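
For context, a minimal standalone sketch of the API this commit switches to (not part of the repository; assumes PyTorch 2.3+, where torch.nn.attention.sdpa_kernel supersedes the deprecated torch.backends.cuda.sdp_kernel context manager; tensor shapes and the is_causal flag are illustrative):

```python
# Sketch only: restrict F.scaled_dot_product_attention to specific backends
# using torch.nn.attention.sdpa_kernel, as in the patched Attend module above.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Allow only the math and memory-efficient kernels, mirroring the backends
# enabled in this commit.
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])
```

Including SDPBackend.MATH in the list keeps a universal fallback, so the call still succeeds when the memory-efficient kernel is unavailable (for example on CPU or with unsupported dtypes).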