Fix CPU Fallback for NewAttention with xformers BlockDiagonalMask
Description:
This PR addresses two critical issues when running the model on CPU:
1. Memory-efficient attention CPU fallback:
- Added CPU device detection in NewAttention initialization
- Automatically disables xformers memory-efficient attention when running on CPU
- Prevents the NotImplementedError from xformers, which only supports CUDA devices
2. BlockDiagonalMask handling in standard attention:
- Added proper handling of the xformers BlockDiagonalMask in the standard attention path
- Materializes the BlockDiagonalMask into a dense tensor before adding it to the attention scores
The fix ensures a smooth fallback to the standard attention mechanism when running on CPU while maintaining compatibility with xformers mask types; a short sketch of the mask handling follows.
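For reference, here is a minimal sketch of what the standard-attention path does with the bias once the fallback is active. The helper name `apply_attention_bias` is illustrative and not part of the patch; it assumes `attention_scores` has shape `(batch, heads, q_len, k_len)`:

```python
import torch

def apply_attention_bias(attention_scores: torch.Tensor, attention_bias) -> torch.Tensor:
    """Add an attention bias to raw scores, converting xformers block-diagonal
    masks to dense tensors so the addition also works on CPU."""
    if attention_bias is None:
        return attention_scores
    # xformers attention-bias objects (e.g. BlockDiagonalMask) expose .materialize();
    # plain tensors do not, so hasattr() distinguishes the two cases.
    if hasattr(attention_bias, "materialize"):
        attention_bias = attention_bias.materialize(
            attention_scores.shape,          # (batch, heads, q_len, k_len)
            dtype=attention_scores.dtype,
            device=attention_scores.device,
        )
    return attention_scores + attention_bias
```

This mirrors the `hasattr(attention_bias, 'materialize')` check in the diff below, just factored out for readability.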
Related issue:
```
NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(1, 5471, 16, 64) (torch.float32)
     key         : shape=(1, 5471, 16, 64) (torch.float32)
     value       : shape=(1, 5471, 16, 64) (torch.float32)
     attn_bias   : <class 'xformers.ops.fmha.attn_bias.BlockDiagonalMask'>
     p           : 0.0
`fa2F@v2.5.7-pt` is not supported because:
    device=cpu (supported: {'cuda'})
    dtype=torch.float32 (supported: {torch.float16, torch.bfloat16})
`cutlassF-pt` is not supported because:
    device=cpu (supported: {'cuda'})
```
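A rough reproduction of the failure (sequence lengths are scaled down from the trace above; any CPU float32 input with a BlockDiagonalMask bias triggers it, assuming xformers is installed):

```python
import torch
import xformers.ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

# xformers expects (batch, seq_len, num_heads, head_dim); everything on CPU in float32.
q = torch.randn(1, 12, 16, 64)
k = torch.randn(1, 12, 16, 64)
v = torch.randn(1, 12, 16, 64)

# Two packed sequences of lengths 5 and 7 (5 + 7 = 12, the seq_len above).
attn_bias = BlockDiagonalMask.from_seqlens([5, 7])

# No CPU kernel supports this combination -> NotImplementedError as in the trace above.
out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
```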
- modeling.py (+13 −0)

```diff
@@ -445,6 +445,10 @@ class NewAttention(nn.Module):
 
         if use_memory_efficient_attention is None:
             use_memory_efficient_attention = self.config.use_memory_efficient_attention
+
+        if not torch.cuda.is_available() or (hasattr(config, 'device') and config.device == 'cpu'):
+            use_memory_efficient_attention = False
+
         self.use_memory_efficient_attention = use_memory_efficient_attention
         self.memory_efficient_attention = None if xops is None else xops.memory_efficient_attention
         if self.use_memory_efficient_attention:
@@ -489,6 +493,9 @@ class NewAttention(nn.Module):
             key_states = pad_input(key_states.squeeze(), *padding_inputs)
             value_states = pad_input(value_states.squeeze(), *padding_inputs)
 
+        if self.use_memory_efficient_attention and not hidden_states.is_cuda:
+            self.use_memory_efficient_attention = False
+
         if self.use_memory_efficient_attention:
             assert self.memory_efficient_attention is not None, "xformers is not loaded"
             assert output_attentions is False, "memory_efficient_attention do not output attentions"
@@ -534,6 +541,12 @@ class NewAttention(nn.Module):
 
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         if attention_bias is not None:
+            if hasattr(attention_bias, 'materialize'):
+                # If it's a BlockDiagonalMask, materialize it to a tensor
+                attention_bias = attention_bias.materialize(
+                    (attention_scores.shape[0], attention_scores.shape[1],
+                     attention_scores.shape[2], attention_scores.shape[3])
+                )
             # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
             attention_scores = attention_scores + attention_bias
 
```
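With the patch applied, a quick CPU smoke test should now go through the standard-attention fallback instead of raising. Something along these lines, where the checkpoint name is a placeholder and trust_remote_code is assumed because modeling.py ships with the model:

```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "org/model-using-NewAttention"  # placeholder, substitute the real checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).to("cpu").eval()

inputs = tokenizer(["hello world"], return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)  # previously raised NotImplementedError on CPU

print(outputs.last_hidden_state.shape)
```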