Spaces:
Running
on
Zero
Running
on
Zero
update
Browse files
vit/vision_transformer.py
CHANGED
@@ -66,7 +66,7 @@ from packaging import version
|
|
66 |
assert version.parse(torch.__version__) >= version.parse("2.0.0")
|
67 |
SDP_IS_AVAILABLE = True
|
68 |
# from torch.backends.cuda import SDPBackend, sdp_kernel
|
69 |
-
from torch.nn.attention import sdpa_kernel, SDPBackend
|
70 |
|
71 |
|
72 |
class Attention(nn.Module):
|
@@ -110,7 +110,7 @@ class Attention(nn.Module):
|
|
110 |
self.no_flash_op = no_flash_op
|
111 |
self.attn_mode = "torch"
|
112 |
|
113 |
-
self.backend = SDPBackend.FLASH_ATTENTION # FA implemented by torch.
|
114 |
|
115 |
@staticmethod
|
116 |
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
@@ -198,9 +198,8 @@ class Attention(nn.Module):
|
|
198 |
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
199 |
q, k = self.q_norm(q), self.k_norm(k)
|
200 |
|
201 |
-
with sdpa_kernel([self.backend]): # new signature
|
202 |
-
|
203 |
-
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
204 |
|
205 |
del q, k, v
|
206 |
x = rearrange(x, "B H L D -> B L (H D)")
|
|
|
66 |
assert version.parse(torch.__version__) >= version.parse("2.0.0")
|
67 |
SDP_IS_AVAILABLE = True
|
68 |
# from torch.backends.cuda import SDPBackend, sdp_kernel
|
69 |
+
# from torch.nn.attention import sdpa_kernel, SDPBackend
|
70 |
|
71 |
|
72 |
class Attention(nn.Module):
|
|
|
110 |
self.no_flash_op = no_flash_op
|
111 |
self.attn_mode = "torch"
|
112 |
|
113 |
+
# self.backend = SDPBackend.FLASH_ATTENTION # FA implemented by torch.
|
114 |
|
115 |
@staticmethod
|
116 |
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
|
|
198 |
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
199 |
q, k = self.q_norm(q), self.k_norm(k)
|
200 |
|
201 |
+
# with sdpa_kernel([self.backend]): # new signature
|
202 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
|
|
203 |
|
204 |
del q, k, v
|
205 |
x = rearrange(x, "B H L D -> B L (H D)")
|