Spaces:
Sleeping
Sleeping
0.19 implementing flash_attn
Browse files- app.py +9 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -4,6 +4,7 @@ import torch
|
|
4 |
import gradio as gr
|
5 |
import logging
|
6 |
from huggingface_hub import login
|
|
|
7 |
|
8 |
import os
|
9 |
import traceback
|
@@ -66,6 +67,10 @@ def load_model_a(model_id):
|
|
66 |
device_map="auto",
|
67 |
trust_remote_code=True,
|
68 |
).eval()
|
|
|
|
|
|
|
|
|
69 |
except Exception as e:
|
70 |
logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
|
71 |
|
@@ -83,6 +88,10 @@ def load_model_b(model_id):
|
|
83 |
device_map="auto",
|
84 |
trust_remote_code=True,
|
85 |
).eval()
|
|
|
|
|
|
|
|
|
86 |
except Exception as e:
|
87 |
logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
|
88 |
return gr.update(label=model_id)
|
|
|
4 |
import gradio as gr
|
5 |
import logging
|
6 |
from huggingface_hub import login
|
7 |
+
from flash_attn.flash_attention import FlashAttention
|
8 |
|
9 |
import os
|
10 |
import traceback
|
|
|
67 |
device_map="auto",
|
68 |
trust_remote_code=True,
|
69 |
).eval()
|
70 |
+
for name, module in model_a.named_modules():
|
71 |
+
if isinstance(module, torch.nn.MultiheadAttention):
|
72 |
+
module.forward = FlashAttention(module.embed_dim)
|
73 |
+
logging.debug(f'{SPACER} forwarding module of {model_id_a} to flash_attn')
|
74 |
except Exception as e:
|
75 |
logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
|
76 |
|
|
|
88 |
device_map="auto",
|
89 |
trust_remote_code=True,
|
90 |
).eval()
|
91 |
+
for name, module in model_b.named_modules():
|
92 |
+
if isinstance(module, torch.nn.MultiheadAttention):
|
93 |
+
module.forward = FlashAttention(module.embed_dim)
|
94 |
+
logging.debug(f'{SPACER} forwarding module of {model_id_b} to flash_attn')
|
95 |
except Exception as e:
|
96 |
logging.error(f'{SPACER} Error: {e}, Traceback {traceback.format_exc()}')
|
97 |
return gr.update(label=model_id)
|
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ accelerate==0.33.0
|
|
5 |
sentencepiece==0.2.0
|
6 |
spaces==0.29.2
|
7 |
gradio==4.39.0
|
8 |
-
bitsandbytes==0.43.2
|
|
|
|
5 |
sentencepiece==0.2.0
|
6 |
spaces==0.29.2
|
7 |
gradio==4.39.0
|
8 |
+
bitsandbytes==0.43.2
|
9 |
+
flash-attn
|