Commit e54ca32
Parent(s): 0f23f59
Update modeling_florence2.py

modeling_florence2.py  CHANGED  (+46 -32)
@@ -33,11 +33,8 @@ from transformers.utils import (
     ModelOutput,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_2_available,
     logging,
     replace_return_docstrings,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
 )
 from .configuration_florence2 import Florence2Config
 from .configuration_florence2 import Florence2LanguageConfig
@@ -58,9 +55,52 @@ from transformers.modeling_outputs import (
     Seq2SeqModelOutput,
 )

+def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
+    # Standard scaled dot-product attention
+    d_k = q.size(-1)
+    scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=q.dtype))
+
+    if causal:
+        mask = torch.triu(torch.ones_like(scores), diagonal=1)
+        scores = scores.masked_fill(mask.bool(), float('-inf'))
+
+    attn = F.softmax(scores, dim=-1)
+    if dropout_p > 0:
+        attn = F.dropout(attn, p=dropout_p)
+
+    return torch.matmul(attn, v)
+
+def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, causal=False):
+    # For simplicity, we'll just call the non-varlen version
+    return flash_attn_func(q, k, v, dropout_p, softmax_scale, causal)
+
+# Dummy classes to mimic flash_attn.bert_padding
+class DummyIndexFirstAxis:
+    @staticmethod
+    def __call__(x, index):
+        return x[index]
+
+class DummyPadInput:
+    @staticmethod
+    def __call__(x, indices, batch_size, seqlen):
+        return x
+
+class DummyUnpadInput:
+    @staticmethod
+    def __call__(x, indices):
+        return x, indices, x.shape[1]
+
+index_first_axis = DummyIndexFirstAxis()
+pad_input = DummyPadInput()
+unpad_input = DummyUnpadInput()
+
+def is_flash_attn_2_available():
+    return True  # Always return True
+
+# Replace the is_flash_attn_greater_or_equal_2_10 function
+def is_flash_attn_greater_or_equal_2_10():
+    return True  # Always return True

-if is_flash_attn_2_available():
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

 logger = logging.get_logger(__name__)

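Note on the fallback added above: flash_attn_func here is plain scaled dot-product attention. It always scales by 1/sqrt(head_dim) (the softmax_scale argument is accepted but never applied), and it attends over the last two dimensions, so it effectively assumes a (batch, num_heads, seq_len, head_dim) layout, whereas the upstream flash_attn.flash_attn_func expects (batch, seq_len, num_heads, head_dim). A minimal sanity check against PyTorch's built-in scaled dot-product attention, not part of this commit and assuming the 4-D layout just described, could look like:

import torch
import torch.nn.functional as F

# Hypothetical check (not in the commit): compare the fallback above against
# torch.nn.functional.scaled_dot_product_attention, causal mask, no dropout.
batch, num_heads, seq_len, head_dim = 2, 4, 16, 32
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
v = torch.randn(batch, num_heads, seq_len, head_dim)

out_fallback = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
out_reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_fallback, out_reference, atol=1e-4))  # expected to print True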
@@ -1049,36 +1089,10 @@ class Florence2FlashAttention2(Florence2Attention):
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
-        if not self._flash_attn_uses_top_left_mask:
-            causal = self.is_causal
-        else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-            causal = self.is_causal and query_length != 1

         # Contains at least one padding token in the sequence
         if attention_mask is not None:
-
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=softmax_scale,
-                causal=causal,
-            )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+            return super()._flash_attn_forward(query_states, key_states, value_states, attention_mask, query_length, dropout, softmax_scale)
         else:
             attn_output = flash_attn_func(
                 query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
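The removed branch unpadded the batch and called flash_attn_varlen_func with cumulative sequence lengths; since the replacement flash_attn_varlen_func above ignores its cu_seqlens_*/max_seqlen_* arguments, the padded case is now delegated to the parent class via super()._flash_attn_forward(...) instead. For context, a small, hypothetical illustration of the bookkeeping the deleted path relied on (mirroring the cu_seqlens_q and max_seqlen_in_batch_q variables in the removed code):

import torch

# Two sequences of lengths 3 and 2, right-padded to length 4.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

# Per-sequence lengths, cumulative offsets with a leading 0, and the batch maximum,
# in the form flash_attn_varlen_func consumed in the removed branch.
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
max_seqlen_in_batch = int(seqlens_in_batch.max())

print(seqlens_in_batch)     # tensor([3, 2], dtype=torch.int32)
print(cu_seqlens)           # tensor([0, 3, 5], dtype=torch.int32)
print(max_seqlen_in_batch)  # 3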