Fix CPU Fallback for NewAttention with xformers BlockDiagonalMask
#25
by m7mdhka - opened

modeling.py  +13 -0  CHANGED
@@ -445,6 +445,10 @@ class NewAttention(nn.Module):
 
         if use_memory_efficient_attention is None:
             use_memory_efficient_attention = self.config.use_memory_efficient_attention
+
+        if not torch.cuda.is_available() or (hasattr(config, 'device') and config.device == 'cpu'):
+            use_memory_efficient_attention = False
+
         self.use_memory_efficient_attention = use_memory_efficient_attention
         self.memory_efficient_attention = None if xops is None else xops.memory_efficient_attention
         if self.use_memory_efficient_attention:
@@ -489,6 +493,9 @@ class NewAttention(nn.Module):
             key_states = pad_input(key_states.squeeze(), *padding_inputs)
             value_states = pad_input(value_states.squeeze(), *padding_inputs)
 
+        if self.use_memory_efficient_attention and not hidden_states.is_cuda:
+            self.use_memory_efficient_attention = False
+
         if self.use_memory_efficient_attention:
             assert self.memory_efficient_attention is not None, "xformers is not loaded"
             assert output_attentions is False, "memory_efficient_attention do not output attentions"
@@ -534,6 +541,12 @@ class NewAttention(nn.Module):
 
             attention_scores = attention_scores / math.sqrt(self.attention_head_size)
             if attention_bias is not None:
+                if hasattr(attention_bias, 'materialize'):
+                    # If it's a BlockDiagonalMask, materialize it to a tensor
+                    attention_bias = attention_bias.materialize(
+                        (attention_scores.shape[0], attention_scores.shape[1],
+                         attention_scores.shape[2], attention_scores.shape[3])
+                    )
                 # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
                 attention_scores = attention_scores + attention_bias
 
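Why the eager branch probes hasattr(attention_bias, 'materialize'): xformers' BlockDiagonalMask is an attention-bias object rather than a tensor, so it cannot be added to attention_scores directly on the CPU path. Below is a minimal sketch of that conversion, assuming the xformers attn_bias API; the sequence lengths and tensor shapes are invented for illustration and none of this code is part of the patch itself.

# Minimal sketch (not part of this patch): converting a BlockDiagonalMask into a
# dense additive bias on CPU, mirroring the new eager-path branch above.
# Assumes xformers' attn_bias API; seqlens and shapes are invented for illustration.
import math
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

# Two packed sequences of lengths 2 and 3 -> total packed length 5.
mask = BlockDiagonalMask.from_seqlens([2, 3])

batch, num_heads, seq_len, head_size = 1, 1, 5, 8
query = torch.randn(batch, num_heads, seq_len, head_size)
key = torch.randn(batch, num_heads, seq_len, head_size)
attention_scores = query @ key.transpose(-1, -2) / math.sqrt(head_size)

# Same conversion the patch performs: materialize() returns a dense tensor that is
# 0 inside each sequence's block and -inf across blocks, so it can simply be added.
if hasattr(mask, "materialize"):
    attention_bias = mask.materialize(
        (batch, num_heads, seq_len, seq_len), dtype=attention_scores.dtype
    )
else:
    attention_bias = mask

attention_scores = attention_scores + attention_bias
attention_probs = attention_scores.softmax(dim=-1)
print(attention_probs.shape)  # torch.Size([1, 1, 5, 5]); cross-sequence weights ~= 0

The other hunks (the torch.cuda.is_available() / config.device check in __init__ and the hidden_states.is_cuda check in forward()) make sure this eager path, rather than xops.memory_efficient_attention, is the one taken when the model actually runs on CPU.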