Add SDPA attention #2
opened by Katsumata420
- modeling_retrieva_bert.py +173 -19

modeling_retrieva_bert.py CHANGED
@@ -34,6 +34,7 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
@@ -49,6 +50,10 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from transformers.utils import (
@@ -56,6 +61,7 @@ from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    get_torch_version,
     logging,
     replace_return_docstrings,
 )
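A note on the two new imports above: `packaging.version` and `get_torch_version` exist only to decide, once at construction time, whether the q/k/v tensors must be made contiguous before calling SDPA (a torch < 2.2.0 workaround used in the new attention class below). A minimal sketch of that check, using only the helpers the diff itself imports:

```python
from packaging import version
from transformers.utils import get_torch_version

# True on torch 2.1.x and older, False from torch 2.2.0 onward; the new
# RetrievaBertSdpaSelfAttention stores this once in __init__ and only calls
# .contiguous() on CUDA inputs with a mask when the flag is True.
require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
print(require_contiguous_qkv)
```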
@@ -407,6 +413,113 @@ class RetrievaBertSelfAttention(nn.Module):
         return outputs
 
 
+class RetrievaBertSdpaSelfAttention(RetrievaBertSelfAttention):
+    def __init__(self, config):
+        super().__init__(config)
+        self.dropout_prob = config.attention_probs_dropout_prob
+        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        if output_attentions or head_mask is not None:
+            logger.warning_once(
+                "RetrievaBertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation. "
+            )
+            return super().forward(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        mixed_query_layer = self.query(hidden_states)
+        query_layer = self.transpose_for_scores(mixed_query_layer, is_query=True)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        # The following code is based on the implementation of `transformers.BertSdpaSelfAttention`.
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
+            key_layer, value_layer = past_key_value
+        else:
+            key_layer = self.transpose_for_scores(self.key(current_states), is_query=False)
+            value_layer = self.transpose_for_scores(self.value(current_states), is_query=False)
+
+        if self.rope_emb is not None:
+            cos, sin = self.rope_emb(hidden_states, position_ids)
+            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
+
+        if past_key_value is not None and not is_cross_attention:
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+        # For GQA, we repeat the key/value weights.
+        key_layer = repeat_kv(key_layer, self.num_key_value_groups)
+        value_layer = repeat_kv(value_layer, self.num_key_value_groups)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+            query_layer = query_layer.contiguous()
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+        # a causal mask in case tgt_len == 1.
+        is_causal = (
+            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
+        )
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attention_mask,
+            is_causal=is_causal,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+        )
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
+
+        outputs = (attn_output,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
 # Based transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to RetrievaBertAttention below.
 class RetrievaBertSelfOutput(nn.Module):
     def __init__(self, config):
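The heart of the new class is replacing the eager softmax(QK^T / sqrt(d) + mask) @ V computation with a single `torch.nn.functional.scaled_dot_product_attention` call. A self-contained sketch (shapes and tensors below are made up for illustration) showing that the two paths agree when the mask is the additive float mask the model prepares further down:

```python
import math
import torch
import torch.nn.functional as F

bsz, heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 8
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, kv_len, head_dim)
v = torch.randn(bsz, heads, kv_len, head_dim)

# Additive mask: 0.0 where attention is allowed, a very large negative value where it is not.
mask = torch.zeros(bsz, 1, q_len, kv_len)
mask[:, :, :, -1] = torch.finfo(torch.float32).min

# Eager-style attention (what the existing RetrievaBertSelfAttention computes, ignoring dropout/head_mask).
scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim) + mask
eager_out = scores.softmax(dim=-1) @ v

# SDPA path (what RetrievaBertSdpaSelfAttention now calls).
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True
```

At inference time `dropout_p` is 0.0 (see the call above), so the remaining differences between the two implementations come down to kernel-level numerics.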
@@ -420,12 +533,18 @@ class RetrievaBertSelfOutput(nn.Module):
         return residual + hidden_states
 
 
+RETRIEVA_BERT_SELF_ATTENTION_CLASSES = {
+    "eager": RetrievaBertSelfAttention,
+    "sdpa": RetrievaBertSdpaSelfAttention,
+}
+
+
 # Based transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm.
 class RetrievaBertAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.self = RetrievaBertSelfAttention(config)
+        self.self = RETRIEVA_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](config)
         self.output = RetrievaBertSelfOutput(config)
         self.pruned_heads = set()
 
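The mapping above follows the dispatch pattern from `transformers`' own BERT: the implementation is chosen once in `RetrievaBertAttention.__init__` from `config._attn_implementation`, so there is no per-forward branching. A toy sketch of the same pattern with stand-in classes (nothing below is taken from the model file):

```python
class EagerAttention:                  # stand-in for RetrievaBertSelfAttention
    def __init__(self, config):
        self.config = config

class SdpaAttention(EagerAttention):   # stand-in for RetrievaBertSdpaSelfAttention
    pass

SELF_ATTENTION_CLASSES = {"eager": EagerAttention, "sdpa": SdpaAttention}

class DummyConfig:
    _attn_implementation = "sdpa"      # transformers sets this field when the model is loaded

config = DummyConfig()
attn = SELF_ATTENTION_CLASSES[config._attn_implementation](config)
print(type(attn).__name__)  # SdpaAttention
```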
@@ -808,6 +927,7 @@ class RetrievaBertPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_megatron_bert
     base_model_prefix = "bert"
     supports_gradient_checkpointing = True
+    _supports_sdpa = True
 
     def _init_weights(self, module):
         """Initialize the weights"""
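`_supports_sdpa = True` is what lets `PreTrainedModel.from_pretrained` accept `attn_implementation="sdpa"` for this architecture. A usage sketch, assuming a checkpoint that ships this modeling file (the repo id below is a placeholder, and the module path in the last line assumes the usual BERT layout):

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "org/retrieva-bert-checkpoint"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_id,
    attn_implementation="sdpa",  # accepted because of _supports_sdpa = True
    trust_remote_code=True,      # the modeling code lives on the Hub
).eval()

inputs = tokenizer("SDPA attention smoke test.", return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state

# Each layer should now hold the SDPA variant.
print(type(model.encoder.layer[0].attention.self).__name__)  # RetrievaBertSdpaSelfAttention
```

Note that passing `output_attentions=True` or a `head_mask` at call time makes each layer fall back to the eager computation, per the warning in the new `forward`.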
@@ -953,6 +1073,8 @@ class RetrievaBertModel(RetrievaBertPreTrainedModel):
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
 
+        self.attn_implementation = config._attn_implementation
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1046,9 +1168,48 @@ class RetrievaBertModel(RetrievaBertPreTrainedModel):
         if position_ids is None:
             position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
 
-        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-        # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        use_sdpa_attention_masks = (
+            self.attn_implementation == "sdpa"
+            and head_mask is None
+            and not output_attentions
+        )
+
+        extended_attention_mask: torch.Tensor
+        if use_sdpa_attention_masks:
+            # Expand the attention mask for SDPA.
+            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+            if self.config.is_decoder:
+                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    input_shape,
+                    embedding_output,
+                    past_key_values_length,
+                )
+            else:
+                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    attention_mask,
+                    embedding_output.dtype,
+                    tgt_len=seq_length,
+                )
+        else:
+            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+            # ourselves in which case we just need to make it broadcastable to all heads.
+            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
 
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
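For the same 2D padding mask, the two branches above build differently shaped additive masks: `get_extended_attention_mask` keeps the eager-style `[bsz, 1, 1, seq_len]` broadcastable form, while `_prepare_4d_attention_mask_for_sdpa` expands it to `[bsz, 1, tgt_len, src_len]` and, when nothing is actually masked, can return `None` so SDPA can skip masking altogether. A small sketch using only the helper the diff imports (the example mask is made up):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

# Two sequences of length 6; the second one ends with two padding tokens.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 0, 0]])

sdpa_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, torch.float32, tgt_len=6)
print(sdpa_mask.shape)     # torch.Size([2, 1, 6, 6])
print(sdpa_mask[1, 0, 0])  # 0.0 for real tokens, dtype-min for the padded positions

# With an all-ones mask the helper can return None, letting SDPA run unmasked.
print(_prepare_4d_attention_mask_for_sdpa(torch.ones(2, 6, dtype=torch.long), torch.float32, tgt_len=6))
```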
@@ -1057,24 +1218,17 @@ class RetrievaBertModel(RetrievaBertPreTrainedModel):
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            if use_sdpa_attention_masks:
+                # Expand the attention mask for SDPA.
+                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
+                )
+            else:
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
 
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
-        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length,
-        )
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask=extended_attention_mask,