Flash Attention 2 only supports fp16
inference/infer.py  +1 -1
@@ -76,7 +76,7 @@ print(f"Using device: {device}")
 mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 model = AutoModelForCausalLM.from_pretrained(
     stage1_model,
-    torch_dtype=torch.
+    torch_dtype=torch.float16,
     attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
 )
 model.to(device)
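For context, a minimal sketch of the load path after this commit, assuming a placeholder checkpoint id and an illustrative fallback when flash-attn is not installed (neither is part of the commit). The flash-attn 2 kernels accept only half-precision (fp16/bf16) tensors, which is what pinning torch_dtype to float16 satisfies.

# Minimal sketch of the patched load path. Assumptions: the checkpoint id
# below is a placeholder, and the ImportError fallback is illustrative,
# not part of this commit.
import torch
from transformers import AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
stage1_model = "your-org/your-stage1-checkpoint"  # hypothetical checkpoint id

try:
    # flash-attn 2 kernels run on half-precision tensors (fp16/bf16),
    # so the dtype is pinned to float16 when requesting them.
    model = AutoModelForCausalLM.from_pretrained(
        stage1_model,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    )
except ImportError:
    # flash-attn is not installed: fall back to the default attention
    # implementation, keeping fp16 so memory use stays comparable.
    model = AutoModelForCausalLM.from_pretrained(
        stage1_model,
        torch_dtype=torch.float16,
    )
model.to(device)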