ggcristian commited on
Commit
ee706c5
1 Parent(s): fad288c

Update modeling_phi3_v.py

Browse files
Files changed (1) hide show
  1. modeling_phi3_v.py +11 -0
modeling_phi3_v.py CHANGED
@@ -40,6 +40,7 @@ from transformers.utils import (
40
  add_code_sample_docstrings,
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
 
43
  is_flash_attn_greater_or_equal_2_10,
44
  logging,
45
  replace_return_docstrings,
@@ -56,6 +57,16 @@ try:
56
  except ImportError:
57
  pass
58
 
 
 
 
 
 
 
 
 
 
 
59
  logger = logging.get_logger(__name__)
60
 
61
  _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-vision-128k-instruct"
 
40
  add_code_sample_docstrings,
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
43
+ is_flash_attn_2_available,
44
  is_flash_attn_greater_or_equal_2_10,
45
  logging,
46
  replace_return_docstrings,
 
57
  except ImportError:
58
  pass
59
 
60
+ try: # noqa: SIM105
61
+ if is_flash_attn_2_available():
62
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
63
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
64
+ except ImportError:
65
+ # Workaround for https://github.com/huggingface/transformers/issues/28459,
66
+ # don't move to contextlib.suppress(ImportError)
67
+ pass
68
+
69
+
70
  logger = logging.get_logger(__name__)
71
 
72
  _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-vision-128k-instruct"