Fix CPU Fallback for NewAttention with xformers BlockDiagonalMask

#25
by m7mdhka - opened
Files changed (1)
  1. modeling.py +13 -0
modeling.py CHANGED
@@ -445,6 +445,10 @@ class NewAttention(nn.Module):
 
         if use_memory_efficient_attention is None:
             use_memory_efficient_attention = self.config.use_memory_efficient_attention
+
+        if not torch.cuda.is_available() or (hasattr(config, 'device') and config.device == 'cpu'):
+            use_memory_efficient_attention = False
+
         self.use_memory_efficient_attention = use_memory_efficient_attention
         self.memory_efficient_attention = None if xops is None else xops.memory_efficient_attention
         if self.use_memory_efficient_attention:
@@ -489,6 +493,9 @@ class NewAttention(nn.Module):
             key_states = pad_input(key_states.squeeze(), *padding_inputs)
             value_states = pad_input(value_states.squeeze(), *padding_inputs)
 
+        if self.use_memory_efficient_attention and not hidden_states.is_cuda:
+            self.use_memory_efficient_attention = False
+
         if self.use_memory_efficient_attention:
             assert self.memory_efficient_attention is not None, "xformers is not loaded"
             assert output_attentions is False, "memory_efficient_attention do not output attentions"
@@ -534,6 +541,12 @@ class NewAttention(nn.Module):
 
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         if attention_bias is not None:
+            if hasattr(attention_bias, 'materialize'):
+                # If it's a BlockDiagonalMask, materialize it to a tensor
+                attention_bias = attention_bias.materialize(
+                    (attention_scores.shape[0], attention_scores.shape[1],
+                     attention_scores.shape[2], attention_scores.shape[3])
+                )
             # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
             attention_scores = attention_scores + attention_bias
 
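
For reference, here is a minimal sketch (not part of the patch) of the two behaviors the fallback relies on: xformers' memory-efficient attention only runs on CUDA tensors, and its BlockDiagonalMask exposes a materialize() method that turns the mask into a dense additive bias the eager attention path can add to the scores. The sequence lengths and shapes below are made up for illustration.

# Minimal sketch, not part of the patch; shapes and sequence lengths are illustrative.
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

# 1) CPU fallback: memory-efficient attention requires CUDA tensors, so the
#    eager (matmul + softmax) path has to be used when CUDA is unavailable.
use_memory_efficient_attention = torch.cuda.is_available()

# 2) Mask fallback: a BlockDiagonalMask cannot be added to a plain tensor, but
#    materialize() converts it to a dense additive bias (0 inside each block,
#    -inf everywhere else) with the same rank as the attention scores.
mask = BlockDiagonalMask.from_seqlens([2, 3, 4])   # packed batch of three sequences, 9 tokens total
bias = mask.materialize((1, 1, 9, 9))              # (batch, heads, q_len, kv_len); float32 on CPU by default

attention_scores = torch.zeros(1, 1, 9, 9)
attention_scores = attention_scores + bias         # off-block positions become -inf before softmax

The extra check on hidden_states.is_cuda in forward covers the case where the module is constructed while CUDA is available but the model is later moved to CPU, so the eager path is taken there as well.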