yslan committed
Commit faf0d54 · 1 Parent(s): e7e3673
Files changed (1)
  1. vit/vision_transformer.py +4 -5
vit/vision_transformer.py CHANGED
@@ -66,7 +66,7 @@ from packaging import version
 assert version.parse(torch.__version__) >= version.parse("2.0.0")
 SDP_IS_AVAILABLE = True
 # from torch.backends.cuda import SDPBackend, sdp_kernel
-from torch.nn.attention import sdpa_kernel, SDPBackend
+# from torch.nn.attention import sdpa_kernel, SDPBackend
 
 
 class Attention(nn.Module):
@@ -110,7 +110,7 @@ class Attention(nn.Module):
         self.no_flash_op = no_flash_op
         self.attn_mode = "torch"
 
-        self.backend = SDPBackend.FLASH_ATTENTION  # FA implemented by torch.
+        # self.backend = SDPBackend.FLASH_ATTENTION  # FA implemented by torch.
 
     @staticmethod
     def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
@@ -198,9 +198,8 @@ class Attention(nn.Module):
         q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
         q, k = self.q_norm(q), self.k_norm(k)
 
-        with sdpa_kernel([self.backend]):  # new signature
-
-            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+        # with sdpa_kernel([self.backend]):  # new signature
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
 
         del q, k, v
         x = rearrange(x, "B H L D -> B L (H D)")
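Net effect of the change: the file no longer imports torch.nn.attention (which only exists in newer PyTorch 2.x releases) and instead calls scaled_dot_product_attention directly, letting PyTorch pick the attention backend on any torch >= 2.0. A minimal, version-tolerant sketch of the same idea is shown below; the helper name sdpa and the try/except guard are illustrative assumptions, not code from this repository.

import torch.nn.functional as F

def sdpa(q, k, v):
    # q, k, v: (B, H, L, D), as in Attention.forward above.
    try:
        # Newer PyTorch (torch.nn.attention available): optionally pin the
        # FlashAttention kernel. This path assumes a CUDA device and
        # fp16/bf16 inputs, otherwise the pinned backend has no kernel to run.
        from torch.nn.attention import sdpa_kernel, SDPBackend
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
            return F.scaled_dot_product_attention(q, k, v)
    except ImportError:
        # Older torch >= 2.0: let SDPA choose its own backend, which is what
        # the committed code now does unconditionally.
        return F.scaled_dot_product_attention(q, k, v)

Dropping the backend pin trades a guaranteed FlashAttention kernel for portability: PyTorch will still dispatch to FlashAttention when the inputs and hardware allow it, and otherwise fall back to the memory-efficient or math kernels.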