asigalov61 committed on
Commit
c10b216
1 Parent(s): 435839f

Upload x_transformer_1_23_2.py

Files changed (1)
  1. x_transformer_1_23_2.py +11 -4
x_transformer_1_23_2.py CHANGED
@@ -268,7 +268,8 @@ class Attend(nn.Module):
         # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
 
         # PyTorch 2.3-2.4 SDPA backend code...
-        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
+        # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
+        with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
 
         # New PyTorch 2.5 SDPA backend code:
         # with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
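For context: this hunk narrows the PyTorch 2.3-2.4 backend list down to Flash Attention only. A minimal sketch of the `torch.nn.attention.sdpa_kernel` context manager being used here, assuming a CUDA device and half-precision tensors; the snippet is illustrative and not part of the commit:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Illustrative shapes: (batch, heads, seq_len, head_dim). Flash Attention
# requires a CUDA device and fp16/bf16 inputs.
q = k = v = torch.randn(1, 8, 128, 64, device = 'cuda', dtype = torch.float16)

# Restrict scaled_dot_product_attention to the Flash Attention kernel, as
# the edited line does; if no listed backend supports the inputs, the call
# raises an error instead of silently falling back to another kernel.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
```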
@@ -501,7 +502,8 @@ class AutoregressiveWrapper(Module):
         ignore_index = -100,
         pad_value = 0,
         mask_prob = 0.,
-        add_attn_z_loss = False
+        add_attn_z_loss = False,
+        return_cache=False
     ):
         super().__init__()
         self.pad_value = pad_value
@@ -516,6 +518,7 @@ class AutoregressiveWrapper(Module):
 
         # whether to add router z-loss
         self.add_attn_z_loss = add_attn_z_loss
+        self.return_cache = return_cache
 
     @torch.inference_mode()
     @eval_decorator
@@ -709,8 +712,12 @@ class AutoregressiveWrapper(Module):
 
         if add_attn_z_loss:
             loss = loss + cache.attn_z_loss
-
-        return loss, acc
+
+        if self.return_cache:
+            return loss, acc, cache
+
+        else:
+            return loss, acc
 
 #===============================================================================
 