x54-729
commited on
Commit
•
6e1fdc1
1
Parent(s):
bcad9ec
fix import error
Browse files- modeling_internlm.py +19 -5
modeling_internlm.py
CHANGED
@@ -48,6 +48,20 @@ logger = logging.get_logger(__name__)
|
|
48 |
|
49 |
_CONFIG_FOR_DOC = "InternLMConfig"
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def _get_unpad_data(attention_mask):
|
52 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
53 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
@@ -438,13 +452,11 @@ class InternLMFlashAttention2(InternLMAttention):
|
|
438 |
softmax_scale (`float`, *optional*):
|
439 |
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
440 |
"""
|
441 |
-
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
442 |
-
from flash_attn.bert_padding import pad_input
|
443 |
# Contains at least one padding token in the sequence
|
444 |
causal = self.is_causal and query_length != 1
|
445 |
if attention_mask is not None:
|
446 |
batch_size = query_states.shape[0]
|
447 |
-
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self.
|
448 |
query_states, key_states, value_states, attention_mask, query_length
|
449 |
)
|
450 |
|
@@ -472,8 +484,7 @@ class InternLMFlashAttention2(InternLMAttention):
|
|
472 |
|
473 |
return attn_output
|
474 |
|
475 |
-
def
|
476 |
-
from flash_attn.bert_padding import index_first_axis, unpad_input
|
477 |
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
478 |
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
479 |
|
@@ -762,6 +773,9 @@ class InternLMModel(InternLMPreTrainedModel):
|
|
762 |
|
763 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
764 |
|
|
|
|
|
|
|
765 |
# retrieve input_ids and inputs_embeds
|
766 |
if input_ids is not None and inputs_embeds is not None:
|
767 |
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
|
|
48 |
|
49 |
_CONFIG_FOR_DOC = "InternLMConfig"
|
50 |
|
51 |
+
flash_attn_func, flash_attn_varlen_func = None, None
|
52 |
+
pad_input, index_first_axis, unpad_input = None, None, None
|
53 |
+
def _import_flash_attn():
|
54 |
+
global flash_attn_func, flash_attn_varlen_func
|
55 |
+
global pad_input, index_first_axis, unpad_input
|
56 |
+
try:
|
57 |
+
from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
|
58 |
+
from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
|
59 |
+
flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
|
60 |
+
pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
|
61 |
+
except ImportError:
|
62 |
+
raise ImportError("flash_attn is not installed.")
|
63 |
+
|
64 |
+
|
65 |
def _get_unpad_data(attention_mask):
|
66 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
67 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
|
|
452 |
softmax_scale (`float`, *optional*):
|
453 |
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
454 |
"""
|
|
|
|
|
455 |
# Contains at least one padding token in the sequence
|
456 |
causal = self.is_causal and query_length != 1
|
457 |
if attention_mask is not None:
|
458 |
batch_size = query_states.shape[0]
|
459 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
|
460 |
query_states, key_states, value_states, attention_mask, query_length
|
461 |
)
|
462 |
|
|
|
484 |
|
485 |
return attn_output
|
486 |
|
487 |
+
def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
|
|
488 |
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
489 |
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
490 |
|
|
|
773 |
|
774 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
775 |
|
776 |
+
if self.config.attn_implementation == "flash_attention_2":
|
777 |
+
_import_flash_attn()
|
778 |
+
|
779 |
# retrieve input_ids and inputs_embeds
|
780 |
if input_ids is not None and inputs_embeds is not None:
|
781 |
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|