remove flash_attn imports and usage

#1
Files changed (1)
  1. modeling_minicpm.py +6 -6
modeling_minicpm.py CHANGED
@@ -51,11 +51,11 @@ from transformers.utils.import_utils import is_torch_fx_available
 from .configuration_minicpm import MiniCPMConfig
 import re
 
-try:
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-except:
-    pass
+#try:
+#    from flash_attn import flash_attn_func, flash_attn_varlen_func
+#    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+#except:
+#    pass
 
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
@@ -755,7 +755,7 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
 
 MINICPM_ATTENTION_CLASSES = {
     "eager": MiniCPMAttention,
-    "flash_attention_2": MiniCPMFlashAttention2,
+    #"flash_attention_2": MiniCPMFlashAttention2,
     "sdpa": MiniCPMSdpaAttention,
 }
 
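
With the flash_attn import commented out and "flash_attention_2" dropped from MINICPM_ATTENTION_CLASSES, only the "eager" and "sdpa" backends remain. A minimal loading sketch, under assumptions not stated in this PR (the checkpoint id below is a guess, and the file is assumed to be loaded via trust_remote_code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint id for illustration; use the repo this modeling file ships with.
model_id = "openbmb/MiniCPM-2B-sft-bf16"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,       # picks up the custom modeling_minicpm.py
    attn_implementation="sdpa",   # "eager" also works; "flash_attention_2" is no longer registered
    torch_dtype=torch.bfloat16,
)

prompt = "Hello"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Requesting attn_implementation="flash_attention_2" against this version of the file would fail with a KeyError, since the key is no longer present in MINICPM_ATTENTION_CLASSES.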