Commit e5ae04a by KingNish (verified)
1 parent: c9e6e07

Flash Attention 2 only supports fp16

Files changed (1):
  1. inference/infer.py +1 -1
inference/infer.py CHANGED
@@ -76,7 +76,7 @@ print(f"Using device: {device}")
 mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 model = AutoModelForCausalLM.from_pretrained(
     stage1_model,
-    torch_dtype=torch.bfloat16,
+    torch_dtype=torch.float16,
     attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
 )
 model.to(device)
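For context, a minimal, self-contained sketch of the Stage 1 model load with this change applied. The checkpoint id and the device selection below are illustrative assumptions rather than values taken from the repository; only the torch_dtype and attn_implementation arguments mirror the diff above.

import torch
from transformers import AutoModelForCausalLM

stage1_model = "your-org/your-stage1-checkpoint"  # hypothetical placeholder id
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed device pick

model = AutoModelForCausalLM.from_pretrained(
    stage1_model,
    torch_dtype=torch.float16,                # fp16 weights, matching this commit
    attn_implementation="flash_attention_2",  # requires `pip install flash-attn`
)
model.to(device)
model.eval()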