baohuynhbk14 committed on
Commit
412554a
·
1 Parent(s): 1993f10

Install flash-attn package and set default device to CUDA in app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -29,6 +29,11 @@ import traceback
29
  # import torch
30
  from conversation import Conversation
31
  from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
 
 
 
 
 
32
 
33
  logger = build_logger("gradio_web_server", "gradio_web_server.log")
34
 
 
29
  # import torch
30
  from conversation import Conversation
31
  from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
32
+ import subprocess
33
+
34
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
35
+
36
+ torch.set_default_device('cuda')
37
 
38
  logger = build_logger("gradio_web_server", "gradio_web_server.log")
39