multimodalart HF staff commited on
Commit
e54ca32
1 Parent(s): 0f23f59

Update modeling_florence2.py

Browse files
Files changed (1) hide show
  1. modeling_florence2.py +46 -32
modeling_florence2.py CHANGED
@@ -33,11 +33,8 @@ from transformers.utils import (
33
  ModelOutput,
34
  add_start_docstrings,
35
  add_start_docstrings_to_model_forward,
36
- is_flash_attn_2_available,
37
  logging,
38
  replace_return_docstrings,
39
- is_flash_attn_2_available,
40
- is_flash_attn_greater_or_equal_2_10,
41
  )
42
  from .configuration_florence2 import Florence2Config
43
  from .configuration_florence2 import Florence2LanguageConfig
@@ -58,9 +55,52 @@ from transformers.modeling_outputs import (
58
  Seq2SeqModelOutput,
59
  )
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- if is_flash_attn_2_available():
63
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
64
 
65
  logger = logging.get_logger(__name__)
66
 
@@ -1049,36 +1089,10 @@ class Florence2FlashAttention2(Florence2Attention):
1049
  softmax_scale (`float`, *optional*):
1050
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
1051
  """
1052
- if not self._flash_attn_uses_top_left_mask:
1053
- causal = self.is_causal
1054
- else:
1055
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
1056
- causal = self.is_causal and query_length != 1
1057
 
1058
  # Contains at least one padding token in the sequence
1059
  if attention_mask is not None:
1060
- batch_size = query_states.shape[0]
1061
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
1062
- query_states, key_states, value_states, attention_mask, query_length
1063
- )
1064
-
1065
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
1066
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1067
-
1068
- attn_output_unpad = flash_attn_varlen_func(
1069
- query_states,
1070
- key_states,
1071
- value_states,
1072
- cu_seqlens_q=cu_seqlens_q,
1073
- cu_seqlens_k=cu_seqlens_k,
1074
- max_seqlen_q=max_seqlen_in_batch_q,
1075
- max_seqlen_k=max_seqlen_in_batch_k,
1076
- dropout_p=dropout,
1077
- softmax_scale=softmax_scale,
1078
- causal=causal,
1079
- )
1080
-
1081
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
1082
  else:
1083
  attn_output = flash_attn_func(
1084
  query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
 
33
  ModelOutput,
34
  add_start_docstrings,
35
  add_start_docstrings_to_model_forward,
 
36
  logging,
37
  replace_return_docstrings,
 
 
38
  )
39
  from .configuration_florence2 import Florence2Config
40
  from .configuration_florence2 import Florence2LanguageConfig
 
55
  Seq2SeqModelOutput,
56
  )
57
 
58
+ def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
59
+ # Standard scaled dot-product attention
60
+ d_k = q.size(-1)
61
+ scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=q.dtype))
62
+
63
+ if causal:
64
+ mask = torch.triu(torch.ones_like(scores), diagonal=1)
65
+ scores = scores.masked_fill(mask.bool(), float('-inf'))
66
+
67
+ attn = F.softmax(scores, dim=-1)
68
+ if dropout_p > 0:
69
+ attn = F.dropout(attn, p=dropout_p)
70
+
71
+ return torch.matmul(attn, v)
72
+
73
+ def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=False):
74
+ # For simplicity, we'll just call the non-varlen version
75
+ return flash_attn_func(q, k, v, dropout_p, softmax_scale, causal)
76
+
77
+ # Dummy classes to mimic flash_attn.bert_padding
78
+ class DummyIndexFirstAxis:
79
+ @staticmethod
80
+ def __call__(x, index):
81
+ return x[index]
82
+
83
+ class DummyPadInput:
84
+ @staticmethod
85
+ def __call__(x, indices, batch_size, seqlen):
86
+ return x
87
+
88
+ class DummyUnpadInput:
89
+ @staticmethod
90
+ def __call__(x, indices):
91
+ return x, indices, x.shape[1]
92
+
93
+ index_first_axis = DummyIndexFirstAxis()
94
+ pad_input = DummyPadInput()
95
+ unpad_input = DummyUnpadInput()
96
+
97
+ def is_flash_attn_2_available():
98
+ return True # Always return True
99
+
100
+ # Replace the is_flash_attn_greater_or_equal_2_10 function
101
+ def is_flash_attn_greater_or_equal_2_10():
102
+ return True # Always return True
103
 
 
 
104
 
105
  logger = logging.get_logger(__name__)
106
 
 
1089
  softmax_scale (`float`, *optional*):
1090
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
1091
  """
 
 
 
 
 
1092
 
1093
  # Contains at least one padding token in the sequence
1094
  if attention_mask is not None:
1095
+ return super()._flash_attn_forward(query_states, key_states, value_states, attention_mask, query_length, dropout, softmax_scale)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1096
  else:
1097
  attn_output = flash_attn_func(
1098
  query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal