m7mdhka committed
Commit
b62a12d
1 Parent(s): 24e2e1f

Fix CPU Fallback for NewAttention with xformers BlockDiagonalMask


Description:
This PR addresses two critical issues when running the model on CPU:

1. Memory efficient attention CPU fallback:
- Added CPU device detection in NewAttention initialization
- Automatically disables xformers memory efficient attention when running on CPU
- Prevents the NotImplementedError raised by xformers, which only supports CUDA devices

2. BlockDiagonalMask handling in standard attention:
- Added proper handling of the xformers BlockDiagonalMask in the standard attention path
- Materializes the BlockDiagonalMask into a dense tensor before adding it to the attention scores

The fix ensures a smooth fallback to the standard attention mechanism when running on CPU while maintaining compatibility with xformers mask types.
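For reference, a minimal usage sketch of the intended behavior after this change: loading the model on a CPU-only machine and running a forward pass should now take the standard attention path instead of raising NotImplementedError. The repo id below is a placeholder, not the actual model name.

    # Placeholder repo id; substitute the real model repository.
    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("your-org/your-model")
    model = AutoModel.from_pretrained("your-org/your-model", trust_remote_code=True)
    model = model.to("cpu").eval()

    inputs = tokenizer(["hello world"], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)  # falls back to standard attention on CPU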

Related issue:
NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query     : shape=(1, 5471, 16, 64) (torch.float32)
     key       : shape=(1, 5471, 16, 64) (torch.float32)
     value     : shape=(1, 5471, 16, 64) (torch.float32)
     attn_bias : <class 'xformers.ops.fmha.attn_bias.BlockDiagonalMask'>
     p         : 0.0
`fa2F@v2.5.7-pt` is not supported because:
    device=cpu (supported: {'cuda'})
    dtype=torch.float32 (supported: {torch.float16, torch.bfloat16})
`cutlassF-pt` is not supported because:
    device=cpu (supported: {'cuda'})
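The failure above can be reproduced directly with xformers on a CPU-only host. A minimal sketch, assuming xformers is installed; the sequence lengths are illustrative and merely chosen to sum to the 5471 tokens in the traceback:

    import torch
    import xformers.ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    num_heads, head_dim = 16, 64
    seqlens = [3000, 2471]  # two packed sequences, 5471 tokens total
    attn_bias = BlockDiagonalMask.from_seqlens(seqlens)

    q = torch.randn(1, sum(seqlens), num_heads, head_dim)  # float32 tensors on CPU
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Raises NotImplementedError: no CPU/float32 kernel is available
    out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=0.0)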

Files changed (1)
  1. modeling.py +13 -0
modeling.py CHANGED
@@ -445,6 +445,10 @@ class NewAttention(nn.Module):
 
         if use_memory_efficient_attention is None:
             use_memory_efficient_attention = self.config.use_memory_efficient_attention
+
+        if not torch.cuda.is_available() or (hasattr(config, 'device') and config.device == 'cpu'):
+            use_memory_efficient_attention = False
+
         self.use_memory_efficient_attention = use_memory_efficient_attention
         self.memory_efficient_attention = None if xops is None else xops.memory_efficient_attention
         if self.use_memory_efficient_attention:
@@ -489,6 +493,9 @@
             key_states = pad_input(key_states.squeeze(), *padding_inputs)
             value_states = pad_input(value_states.squeeze(), *padding_inputs)
 
+        if self.use_memory_efficient_attention and not hidden_states.is_cuda:
+            self.use_memory_efficient_attention = False
+
         if self.use_memory_efficient_attention:
             assert self.memory_efficient_attention is not None, "xformers is not loaded"
             assert output_attentions is False, "memory_efficient_attention do not output attentions"
@@ -534,6 +541,12 @@
 
             attention_scores = attention_scores / math.sqrt(self.attention_head_size)
             if attention_bias is not None:
+                if hasattr(attention_bias, 'materialize'):
+                    # If it's a BlockDiagonalMask, materialize it to a tensor
+                    attention_bias = attention_bias.materialize(
+                        (attention_scores.shape[0], attention_scores.shape[1],
+                         attention_scores.shape[2], attention_scores.shape[3])
+                    )
                 # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
                 attention_scores = attention_scores + attention_bias
 
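As a sanity check on the last hunk, here is a small sketch of what materialize() produces for a BlockDiagonalMask: a dense additive bias with 0 inside each sequence's block and -inf elsewhere, which can then be added to the attention scores like an ordinary attention mask. The sizes are illustrative:

    import torch
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    mask = BlockDiagonalMask.from_seqlens([2, 3])  # two packed sequences, 5 tokens total
    bias = mask.materialize((1, 1, 5, 5), dtype=torch.float32)  # (batch, heads, q_len, kv_len)

    attention_scores = torch.randn(1, 1, 5, 5)
    attention_scores = attention_scores + bias        # cross-sequence positions become -inf
    probs = torch.softmax(attention_scores, dim=-1)   # zero attention across sequence boundaries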