Spaces:
Running
on
Zero
Running
on
Zero
asigalov61
committed on
Commit
•
c10b216
1
Parent(s):
435839f
Upload x_transformer_1_23_2.py
Browse files — x_transformer_1_23_2.py +11 -4
x_transformer_1_23_2.py
CHANGED
@@ -268,7 +268,8 @@ class Attend(nn.Module):
|
|
268 |
# with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
|
269 |
|
270 |
# PyTorch 2.3-2.4 SDPA backend code...
|
271 |
-
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
|
|
|
272 |
|
273 |
# New PyTorch 2.5 SDPA backend code:
|
274 |
# with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
|
@@ -501,7 +502,8 @@ class AutoregressiveWrapper(Module):
|
|
501 |
ignore_index = -100,
|
502 |
pad_value = 0,
|
503 |
mask_prob = 0.,
|
504 |
-
add_attn_z_loss = False
|
|
|
505 |
):
|
506 |
super().__init__()
|
507 |
self.pad_value = pad_value
|
@@ -516,6 +518,7 @@ class AutoregressiveWrapper(Module):
|
|
516 |
|
517 |
# whether to add router z-loss
|
518 |
self.add_attn_z_loss = add_attn_z_loss
|
|
|
519 |
|
520 |
@torch.inference_mode()
|
521 |
@eval_decorator
|
@@ -709,8 +712,12 @@ class AutoregressiveWrapper(Module):
|
|
709 |
|
710 |
if add_attn_z_loss:
|
711 |
loss = loss + cache.attn_z_loss
|
712 |
-
|
713 |
-
|
|
|
|
|
|
|
|
|
714 |
|
715 |
#===============================================================================
|
716 |
|
|
|
268 |
# with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
|
269 |
|
270 |
# PyTorch 2.3-2.4 SDPA backend code...
|
271 |
+
# with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
|
272 |
+
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
|
273 |
|
274 |
# New PyTorch 2.5 SDPA backend code:
|
275 |
# with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
|
|
|
502 |
ignore_index = -100,
|
503 |
pad_value = 0,
|
504 |
mask_prob = 0.,
|
505 |
+
add_attn_z_loss = False,
|
506 |
+
return_cache=False
|
507 |
):
|
508 |
super().__init__()
|
509 |
self.pad_value = pad_value
|
|
|
518 |
|
519 |
# whether to add router z-loss
|
520 |
self.add_attn_z_loss = add_attn_z_loss
|
521 |
+
self.return_cache = return_cache
|
522 |
|
523 |
@torch.inference_mode()
|
524 |
@eval_decorator
|
|
|
712 |
|
713 |
if add_attn_z_loss:
|
714 |
loss = loss + cache.attn_z_loss
|
715 |
+
|
716 |
+
if self.return_cache:
|
717 |
+
return loss, acc, cache
|
718 |
+
|
719 |
+
else:
|
720 |
+
return loss, acc
|
721 |
|
722 |
#===============================================================================
|
723 |
|