yangapku committed
Commit 50ea631
1 Parent(s): 5c611a5

add support for flash attn 2

Files changed (1)
  1. modeling_qwen.py +40 -31
modeling_qwen.py CHANGED
@@ -36,10 +36,6 @@ SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
 
-apply_rotary_emb_func = None
-rms_norm = None
-flash_attn_unpadded_func = None
-
 from .configuration_qwen import QWenConfig
 from .qwen_generation_utils import (
     HistoryType,
@@ -57,6 +53,45 @@ _CONFIG_FOR_DOC = "QWenConfig"
 
 QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
 
+apply_rotary_emb_func = None
+rms_norm = None
+flash_attn_unpadded_func = None
+
+
+def _import_flash_attn():
+    global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
+    try:
+        from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
+        apply_rotary_emb_func = __apply_rotary_emb_func
+    except ImportError:
+        logger.warn(
+            "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
+            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
+        )
+
+    try:
+        from flash_attn.ops.rms_norm import rms_norm as __rms_norm
+        rms_norm = __rms_norm
+    except ImportError:
+        logger.warn(
+            "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
+            "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
+        )
+
+    try:
+        import flash_attn
+        if int(flash_attn.__version__.split(".")[0]) >= 2:
+            from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
+        else:
+            from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
+        flash_attn_unpadded_func = __flash_attn_unpadded_func
+    except ImportError:
+        logger.warn(
+            "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
+            "https://github.com/Dao-AILab/flash-attention"
+        )
+
+
 class FlashSelfAttention(torch.nn.Module):
     def __init__(
         self,
@@ -794,33 +829,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             logger.warn("Flash attention will be disabled because it does NOT support fp32.")
 
         if config.use_flash_attn:
-            global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
-            try:
-                from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
-                apply_rotary_emb_func = __apply_rotary_emb_func
-            except ImportError:
-                logger.warn(
-                    "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
-                    "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
-                )
-
-            try:
-                from flash_attn.ops.rms_norm import rms_norm as __rms_norm
-                rms_norm = __rms_norm
-            except ImportError:
-                logger.warn(
-                    "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
-                    "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
-                )
-
-            try:
-                from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
-                flash_attn_unpadded_func = __flash_attn_unpadded_func
-            except ImportError:
-                logger.warn(
-                    "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
-                    "https://github.com/Dao-AILab/flash-attention"
-                )
+            _import_flash_attn()
 
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
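
Note on the change above: FlashAttention 2 renamed the variable-length attention entry point from flash_attn_unpadded_func to flash_attn_varlen_func, so the new _import_flash_attn() helper picks the symbol by the installed major version and is only called lazily when config.use_flash_attn is set. The snippet below is a minimal standalone sketch of that version-gated import, not part of the commit; it assumes only that the flash_attn package exposes __version__ and one of the two function names shown in the diff.

# Sketch (assumption-labelled): mirrors the version check in _import_flash_attn().
import importlib

flash_attn_unpadded_func = None  # stays None when flash-attn is not installed

try:
    flash_attn = importlib.import_module("flash_attn")
    interface = importlib.import_module("flash_attn.flash_attn_interface")
    if int(flash_attn.__version__.split(".")[0]) >= 2:
        # FlashAttention 2 renamed the unpadded kernel entry point.
        flash_attn_unpadded_func = interface.flash_attn_varlen_func
    else:
        flash_attn_unpadded_func = interface.flash_attn_unpadded_func
except ImportError:
    # Without flash-attn the model would fall back to the plain PyTorch attention path.
    pass

print("flash attention available:", flash_attn_unpadded_func is not None)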