visheratin
committed on
Update new model
Browse files- modeling_phi.py +54 -38
modeling_phi.py
CHANGED
@@ -24,11 +24,14 @@ try:
|
|
24 |
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
25 |
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
|
26 |
from flash_attn.ops.fused_dense import FusedDense
|
27 |
-
|
|
|
|
|
28 |
pad_input, unpad_input = None, None
|
29 |
FlashRotaryEmbedding = None
|
30 |
FlashSelfAttention, FlashCrossAttention = None, None
|
31 |
FusedDense = None
|
|
|
32 |
|
33 |
|
34 |
@dataclass
|
@@ -525,7 +528,7 @@ class MHA(nn.Module):
|
|
525 |
softmax_scale: Optional[float] = None,
|
526 |
layer_idx: Optional[int] = None,
|
527 |
return_residual: bool = False,
|
528 |
-
checkpointing: bool =
|
529 |
) -> None:
|
530 |
super().__init__()
|
531 |
|
@@ -607,7 +610,7 @@ class MHA(nn.Module):
|
|
607 |
|
608 |
if self.checkpointing:
|
609 |
attn_output = torch.utils.checkpoint.checkpoint(
|
610 |
-
self.inner_attn, qkv, cu_seqlens
|
611 |
)
|
612 |
else:
|
613 |
attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
|
@@ -616,7 +619,7 @@ class MHA(nn.Module):
|
|
616 |
return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
|
617 |
|
618 |
if self.checkpointing:
|
619 |
-
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=
|
620 |
|
621 |
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
|
622 |
|
@@ -669,11 +672,12 @@ class MHA(nn.Module):
|
|
669 |
self.inner_cross_attn,
|
670 |
q,
|
671 |
kv,
|
672 |
-
causal
|
673 |
-
|
674 |
-
|
675 |
-
cu_seqlens_k
|
676 |
-
max_seqlen_k
|
|
|
677 |
)
|
678 |
else:
|
679 |
attn_output = self.inner_cross_attn(
|
@@ -697,8 +701,9 @@ class MHA(nn.Module):
|
|
697 |
self.inner_cross_attn,
|
698 |
q,
|
699 |
kv,
|
700 |
-
|
701 |
-
|
|
|
702 |
)
|
703 |
|
704 |
return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
|
@@ -835,7 +840,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|
835 |
|
836 |
config_class = PhiConfig
|
837 |
base_model_prefix = "transformer"
|
838 |
-
supports_gradient_checkpointing =
|
839 |
_no_split_modules = ["ParallelBlock"]
|
840 |
|
841 |
def __init__(self, *inputs, **kwargs) -> None:
|
@@ -862,20 +867,20 @@ class PhiPreTrainedModel(PreTrainedModel):
|
|
862 |
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
863 |
**kwargs,
|
864 |
) -> Dict[str, Any]:
|
865 |
-
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
866 |
-
|
867 |
-
|
868 |
-
|
869 |
-
|
870 |
-
|
871 |
-
|
872 |
-
|
873 |
-
|
874 |
-
else:
|
875 |
-
|
876 |
-
|
877 |
-
|
878 |
-
|
879 |
|
880 |
return {
|
881 |
"input_ids": input_ids,
|
@@ -891,17 +896,19 @@ class PhiModel(PhiPreTrainedModel):
|
|
891 |
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
|
892 |
|
893 |
def __init__(self, config: PhiConfig) -> None:
|
|
|
|
|
894 |
super().__init__(config)
|
895 |
|
896 |
self.embd = Embedding(config)
|
897 |
self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
|
898 |
-
self.gradient_checkpointing =
|
899 |
self.post_init()
|
900 |
|
901 |
-
def get_input_embeddings(self):
|
902 |
-
return self.embd
|
903 |
|
904 |
-
def set_input_embeddings(self, new_embeddings) -> None:
|
905 |
self.embd.wte = new_embeddings
|
906 |
|
907 |
def forward(
|
@@ -919,11 +926,20 @@ class PhiModel(PhiPreTrainedModel):
|
|
919 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
920 |
|
921 |
for layer in self.h:
|
922 |
-
|
923 |
-
hidden_states
|
924 |
-
|
925 |
-
|
926 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
927 |
|
928 |
return hidden_states
|
929 |
|
@@ -947,10 +963,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|
947 |
|
948 |
self.post_init()
|
949 |
|
950 |
-
def get_output_embeddings(self):
|
951 |
-
return self.lm_head
|
952 |
|
953 |
-
def set_output_embeddings(self, new_embeddings) -> None:
|
954 |
self.lm_head.linear = new_embeddings
|
955 |
|
956 |
def forward(
|
|
|
24 |
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
25 |
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
|
26 |
from flash_attn.ops.fused_dense import FusedDense
|
27 |
+
print("Using Flash Attention!")
|
28 |
+
except Exception as exc:
|
29 |
+
print(exc)
|
30 |
pad_input, unpad_input = None, None
|
31 |
FlashRotaryEmbedding = None
|
32 |
FlashSelfAttention, FlashCrossAttention = None, None
|
33 |
FusedDense = None
|
34 |
+
print("Not using Flash Attention!")
|
35 |
|
36 |
|
37 |
@dataclass
|
|
|
528 |
softmax_scale: Optional[float] = None,
|
529 |
layer_idx: Optional[int] = None,
|
530 |
return_residual: bool = False,
|
531 |
+
checkpointing: bool = True,
|
532 |
) -> None:
|
533 |
super().__init__()
|
534 |
|
|
|
610 |
|
611 |
if self.checkpointing:
|
612 |
attn_output = torch.utils.checkpoint.checkpoint(
|
613 |
+
self.inner_attn, qkv, None, cu_seqlens, max_seqlen, use_reentrant=False
|
614 |
)
|
615 |
else:
|
616 |
attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
|
|
|
619 |
return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output
|
620 |
|
621 |
if self.checkpointing:
|
622 |
+
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, None, key_padding_mask, use_reentrant=False)
|
623 |
|
624 |
return self.inner_attn(qkv, key_padding_mask=key_padding_mask)
|
625 |
|
|
|
672 |
self.inner_cross_attn,
|
673 |
q,
|
674 |
kv,
|
675 |
+
causal,
|
676 |
+
cu_seqlens_q,
|
677 |
+
max_seqlen_q,
|
678 |
+
cu_seqlens_k,
|
679 |
+
max_seqlen_k,
|
680 |
+
use_reentrant=False,
|
681 |
)
|
682 |
else:
|
683 |
attn_output = self.inner_cross_attn(
|
|
|
701 |
self.inner_cross_attn,
|
702 |
q,
|
703 |
kv,
|
704 |
+
causal,
|
705 |
+
key_padding_mask,
|
706 |
+
use_reentrant=False,
|
707 |
)
|
708 |
|
709 |
return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
|
|
|
840 |
|
841 |
config_class = PhiConfig
|
842 |
base_model_prefix = "transformer"
|
843 |
+
supports_gradient_checkpointing = True
|
844 |
_no_split_modules = ["ParallelBlock"]
|
845 |
|
846 |
def __init__(self, *inputs, **kwargs) -> None:
|
|
|
867 |
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
868 |
**kwargs,
|
869 |
) -> Dict[str, Any]:
|
870 |
+
# if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
871 |
+
# past_key_values = InferenceParams(
|
872 |
+
# max_seqlen=self.config.n_positions,
|
873 |
+
# max_batch_size=input_ids.shape[0],
|
874 |
+
# seqlen_offset=0,
|
875 |
+
# batch_size_offset=0,
|
876 |
+
# key_value_memory_dict={},
|
877 |
+
# lengths_per_sample=None,
|
878 |
+
# )
|
879 |
+
# else:
|
880 |
+
# # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
881 |
+
# past_key_values.seqlen_offset = input_ids.shape[1] - 1
|
882 |
+
# input_ids = input_ids[:, -1].unsqueeze(-1)
|
883 |
+
# attention_mask = attention_mask[:, -1].unsqueeze(-1)
|
884 |
|
885 |
return {
|
886 |
"input_ids": input_ids,
|
|
|
896 |
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"]
|
897 |
|
898 |
def __init__(self, config: PhiConfig) -> None:
|
899 |
+
config.flash_attn = True
|
900 |
+
config.flash_rotary = True
|
901 |
super().__init__(config)
|
902 |
|
903 |
self.embd = Embedding(config)
|
904 |
self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)])
|
905 |
+
self.gradient_checkpointing = True
|
906 |
self.post_init()
|
907 |
|
908 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
909 |
+
return self.embd.wte
|
910 |
|
911 |
+
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
|
912 |
self.embd.wte = new_embeddings
|
913 |
|
914 |
def forward(
|
|
|
926 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
927 |
|
928 |
for layer in self.h:
|
929 |
+
if self.gradient_checkpointing:
|
930 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
931 |
+
layer.__call__,
|
932 |
+
hidden_states,
|
933 |
+
past_key_values,
|
934 |
+
attention_mask,
|
935 |
+
use_reentrant=False,
|
936 |
+
)
|
937 |
+
else:
|
938 |
+
hidden_states = layer(
|
939 |
+
hidden_states,
|
940 |
+
past_key_values=past_key_values,
|
941 |
+
attention_mask=attention_mask,
|
942 |
+
)
|
943 |
|
944 |
return hidden_states
|
945 |
|
|
|
963 |
|
964 |
self.post_init()
|
965 |
|
966 |
+
def get_output_embeddings(self) -> nn.Linear:
|
967 |
+
return self.lm_head.linear
|
968 |
|
969 |
+
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
|
970 |
self.lm_head.linear = new_embeddings
|
971 |
|
972 |
def forward(
|