jupyterjazz committed
Commit 8b64fa8 · 1 Parent(s): f2e0e62

chore: remove parallelmha

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>
mha.py
CHANGED
@@ -7,8 +7,6 @@ import torch
 import torch.nn as nn
 from einops import rearrange, repeat
 
-from flash_attn.utils.distributed import get_dim_for_local_rank
-
 try:
     from flash_attn import (
         flash_attn_kvpacked_func,
@@ -706,316 +704,3 @@ class MHA(nn.Module):
         context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
         out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)
-
-
-class ParallelMHA(nn.Module):
-    """Multi-head self-attention and cross-attention"""
-
-    def __init__(
-        self,
-        embed_dim,
-        num_heads,
-        process_group,
-        num_heads_kv=None,
-        qkv_proj_bias=True,
-        out_proj_bias=True,
-        dropout=0.0,
-        softmax_scale=None,
-        causal=False,
-        layer_idx=None,
-        rotary_emb_dim=0,
-        rotary_emb_base=10000.0,
-        rotary_emb_scale_base=None,
-        rotary_emb_interleaved=False,
-        use_alibi=False,
-        use_flash_attn=False,
-        checkpointing=False,
-        sequence_parallel=True,
-        device=None,
-        dtype=None,
-    ) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.causal = causal
-        self.layer_idx = layer_idx
-        self.rotary_emb_dim = rotary_emb_dim
-        self.use_flash_attn = use_flash_attn
-        self.checkpointing = checkpointing
-        self.process_group = process_group
-        self.world_size = process_group.size()
-        self.local_rank = torch.distributed.get_rank(process_group)
-
-        self.num_heads = num_heads
-        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
-
-        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
-        assert (
-            self.num_heads % self.num_heads_kv == 0
-        ), "num_heads must be divisible by num_heads_kv"
-
-        self.num_heads_per_rank = get_dim_for_local_rank(
-            self.num_heads, self.world_size, self.local_rank
-        )
-        self.num_heads_kv_per_rank = get_dim_for_local_rank(
-            self.num_heads_kv, self.world_size, self.local_rank
-        )
-        self.head_dim = self.embed_dim // num_heads
-        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
-
-        if use_alibi:
-            assert use_flash_attn, "ALiBi code path requires flash_attn"
-            num_heads_local = math.ceil(self.num_heads / self.world_size)
-            alibi_slopes = torch.tensor(
-                get_alibi_slopes(num_heads)[
-                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
-                ],
-                device=device,
-            )
-        else:
-            alibi_slopes = None
-
-        if self.rotary_emb_dim > 0:
-            assert RotaryEmbedding is not None, "rotary_emb is not installed"
-            self.rotary_emb = RotaryEmbedding(
-                self.rotary_emb_dim,
-                base=rotary_emb_base,
-                scale_base=rotary_emb_scale_base,
-                interleaved=rotary_emb_interleaved,
-                device=device,
-            )
-
-        if ColumnParallelLinear is None or RowParallelLinear is None:
-            raise ImportError("fused_dense is not installed")
-        self.Wqkv = ColumnParallelLinear(
-            embed_dim,
-            qkv_dim,
-            process_group,
-            bias=qkv_proj_bias,
-            sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
-            **factory_kwargs,
-        )
-        inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes)
-            if use_flash_attn
-            else SelfAttention
-        )
-        inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes)
-            if use_flash_attn
-            else CrossAttention
-        )
-        self.inner_attn = inner_attn_cls(
-            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
-        )
-        self.inner_cross_attn = inner_cross_attn_cls(
-            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
-        )
-        self.out_proj = RowParallelLinear(
-            embed_dim,
-            embed_dim,
-            process_group,
-            bias=out_proj_bias,
-            sequence_parallel=sequence_parallel,
-            multiple_of=self.head_dim,
-            **factory_kwargs,
-        )
-
-    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
-        dtype = self.out_proj.weight.dtype if dtype is None else dtype
-        device = self.out_proj.weight.device
-        return torch.empty(
-            batch_size,
-            max_seqlen,
-            2,
-            self.num_heads_kv_per_rank,
-            self.head_dim,
-            dtype=dtype,
-            device=device,
-        )
-
-    def _update_kv_cache(self, kv, inference_params):
-        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
-        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-        return _update_kv_cache(kv, inference_params, self.layer_idx)
-
-    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
-        """
-        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
-        q: (batch_size, seqlen_q, nheads, head_dim)
-        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
-        """
-        assert inference_params is not None and inference_params.seqlen_offset > 0
-        assert self.use_flash_attn
-        if self.rotary_emb_dim > 0:
-            assert self.rotary_emb.scale is None, "This code path does not support xPos"
-            self.rotary_emb._update_cos_sin_cache(
-                inference_params.max_seqlen, device=q.device, dtype=q.dtype
-            )
-            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
-        else:
-            rotary_cos, rotary_sin = None, None
-        batch = q.shape[0]
-        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
-        cache_seqlens = (
-            inference_params.lengths_per_sample[:batch]
-            if inference_params.lengths_per_sample is not None
-            else inference_params.seqlen_offset
-        )
-        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
-        context = flash_attn_with_kvcache(
-            q,
-            kv_cache[:, :, 0],
-            kv_cache[:, :, 1],
-            kv[:, :, 0],
-            kv[:, :, 1],
-            rotary_cos=rotary_cos,
-            rotary_sin=rotary_sin,
-            cache_seqlens=cache_seqlens,
-            softmax_scale=self.inner_cross_attn.softmax_scale,
-            causal=self.inner_cross_attn.causal,
-            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
-            alibi_slopes=alibi_slopes,
-        )
-        return context
-
-    def _update_kvcache_attention(self, q, kv, inference_params):
-        """Write kv to inference_params, then do attention"""
-        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
-            # TODO: this only uses seqlen_offset and not lengths_per_sample.
-            kv = self._update_kv_cache(kv, inference_params)
-            return self.inner_cross_attn(q, kv)
-        else:
-            batch = q.shape[0]
-            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
-            cache_seqlens = (
-                inference_params.lengths_per_sample[:batch]
-                if inference_params.lengths_per_sample is not None
-                else inference_params.seqlen_offset
-            )
-            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
-            context = flash_attn_with_kvcache(
-                q,
-                kv_cache[:, :, 0],
-                kv_cache[:, :, 1],
-                kv[:, :, 0],
-                kv[:, :, 1],
-                cache_seqlens=cache_seqlens,
-                softmax_scale=self.inner_cross_attn.softmax_scale,
-                causal=self.inner_cross_attn.causal,
-                alibi_slopes=alibi_slopes,
-            )
-            return context
-
-    def forward(
-        self, x, seqlen=None, inference_params=None, cu_seqlens=None, max_seqlen=None, **kwargs
-    ):
-        """
-        Arguments:
-            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None and cu_seqlens=None.
-                (seqlen, hidden_dim) if cu_seqlens not None, seqlen equal cu_seqlens[-1].
-                If seqlen is not None and cu_seqlens=None, x is (batch * seqlen, hidden_dim). This is so that when we
-                split x during sequence parallel, we split the batch * seqlen dimension
-                (in case batch is small).
-            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
-                of the sequences in the batch, used to index into x. Only applicable when using
-                FlashAttention.
-            max_seqlen: int. Maximum sequence length in the batch.
-        """
-        if cu_seqlens is not None:
-            assert max_seqlen is not None
-            assert seqlen is None
-            assert self.use_flash_attn
-        if inference_params is not None:
-            assert cu_seqlens is None and max_seqlen is None
-        qkv = self.Wqkv(x)
-        if seqlen is not None:
-            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
-        kwargs = (
-            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
-            if self.use_flash_attn
-            else kwargs
-        )
-        seqlen_offset = (
-            0
-            if inference_params is None
-            else (
-                inference_params.lengths_per_sample
-                if inference_params.lengths_per_sample is not None
-                else inference_params.seqlen_offset
-            )
-        )
-        rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else max_seqlen
-        )
-        if self.num_heads_kv == self.num_heads:
-            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-            if (
-                inference_params is None
-                or inference_params.seqlen_offset == 0
-                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
-                or not self.use_flash_attn
-            ):
-                if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(
-                        qkv,
-                        seqlen_offset=seqlen_offset,
-                        cu_seqlens=cu_seqlens,
-                        max_seqlen=rotary_max_seqlen,
-                    )
-                if inference_params is None:
-                    if not self.checkpointing:
-                        context = self.inner_attn(qkv, **kwargs)
-                    else:
-                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
-                else:
-                    context = self._update_kvcache_attention(
-                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
-                    )
-            else:
-                context = self._apply_rotary_update_kvcache_attention(
-                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
-                )
-        else:
-            q = rearrange(
-                qkv[..., : self.num_heads_per_rank * self.head_dim],
-                "... (h d) -> ... h d",
-                d=self.head_dim,
-            )
-            kv = rearrange(
-                qkv[..., self.num_heads_per_rank * self.head_dim :],
-                "... (two hkv d) -> ... two hkv d",
-                two=2,
-                d=self.head_dim,
-            )
-            if (
-                inference_params is None
-                or inference_params.seqlen_offset == 0
-                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
-                or not self.use_flash_attn
-            ):
-                if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(
-                        q,
-                        kv,
-                        seqlen_offset=seqlen_offset,
-                        cu_seqlens=cu_seqlens,
-                        max_seqlen=rotary_max_seqlen,
-                    )
-                if inference_params is None:
-                    if not self.checkpointing:
-                        context = self.inner_cross_attn(q, kv, **kwargs)
-                    else:
-                        context = torch.utils.checkpoint.checkpoint(
-                            self.inner_cross_attn, q, kv, **kwargs
-                        )
-                else:
-                    context = self._update_kvcache_attention(q, kv, inference_params)
-            else:
-                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
-        context = rearrange(context, "... h d -> ... (h d)")
-        if seqlen is not None:
-            context = rearrange(context, "b s d -> (b s) d")
-        out = self.out_proj(context)
-        return out