Muennighoff commited on
Commit
a2975b0
1 Parent(s): 900f290

Update modeling_gritlm7b.py

Browse files
Files changed (1) hide show
  1. modeling_gritlm7b.py +6 -4
modeling_gritlm7b.py CHANGED
@@ -46,11 +46,13 @@ from transformers.utils import (
46
  from transformers import MistralConfig
47
 
48
 
 
49
  try:
50
- from flash_attn import flash_attn_func, flash_attn_varlen_func
51
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
52
-
53
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
54
  except:
55
  pass
56
 
 
46
  from transformers import MistralConfig
47
 
48
 
49
+ # transformers has a bug where it will try to import everything from a custom model file unless there's try/except
50
  try:
51
+ if is_flash_attn_2_available():
52
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
53
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
54
+
55
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
56
  except:
57
  pass
58