tomer-deci commited on
Commit
9817f85
·
1 Parent(s): 761ba47

Update transformers_v4_35_2__modeling_llama.py

Browse files
transformers_v4_35_2__modeling_llama.py CHANGED
@@ -43,14 +43,11 @@ from transformers.utils import (
43
  from transformers.utils.import_utils import is_torch_fx_available
44
  from .transformers_v4_35_2__configuration_llama import LlamaConfig
45
 
 
 
 
 
46
 
47
- if is_flash_attn_2_available():
48
- def import_flash_attn():
49
- from flash_attn import flash_attn_func, flash_attn_varlen_func
50
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
- return flash_attn_func, flash_attn_varlen_func, index_first_axis, pad_input, unpad_input
52
-
53
- flash_attn_func, flash_attn_varlen_func, index_first_axis, pad_input, unpad_input = import_flash_attn()
54
 
55
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
56
  # It means that the function will not be traced through and simply appear as a node in the graph.
 
43
  from transformers.utils.import_utils import is_torch_fx_available
44
  from .transformers_v4_35_2__configuration_llama import LlamaConfig
45
 
46
+ # Deci: commented out to prevent unnecessary dependency
47
+ # if is_flash_attn_2_available():
48
+ # from flash_attn import flash_attn_func, flash_attn_varlen_func
49
+ # from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
50
 
 
 
 
 
 
 
 
51
 
52
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
53
  # It means that the function will not be traced through and simply appear as a node in the graph.