Spaces:
Sleeping
Sleeping
flash attn fix
Browse files- app.py +13 -1
- requirements.txt +1 -2
- utils.py +17 -0
app.py
CHANGED
@@ -10,6 +10,8 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
10 |
import cv2
|
11 |
import traceback
|
12 |
import matplotlib.pyplot as plt
|
|
|
|
|
13 |
|
14 |
# CUDA optimizations
|
15 |
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
@@ -26,9 +28,19 @@ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
|
|
26 |
image_predictor = SAM2ImagePredictor(sam2_model)
|
27 |
|
28 |
model_id = 'microsoft/Florence-2-large'
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
31 |
|
|
|
32 |
def apply_color_mask(frame, mask, obj_id):
|
33 |
cmap = plt.get_cmap("tab10")
|
34 |
color = np.array(cmap(obj_id % 10)[:3]) # Use modulo 10 to cycle through colors
|
|
|
10 |
import cv2
|
11 |
import traceback
|
12 |
import matplotlib.pyplot as plt
|
13 |
+
from utils import load_model_without_flash_attn
|
14 |
+
|
15 |
|
16 |
# CUDA optimizations
|
17 |
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
|
|
28 |
image_predictor = SAM2ImagePredictor(sam2_model)
|
29 |
|
30 |
model_id = 'microsoft/Florence-2-large'
|
31 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
32 |
+
|
33 |
+
def load_florence_model():
|
34 |
+
return AutoModelForCausalLM.from_pretrained(
|
35 |
+
model_id,
|
36 |
+
trust_remote_code=True,
|
37 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32
|
38 |
+
).eval().to(device)
|
39 |
+
|
40 |
+
florence_model = load_model_without_flash_attn(load_florence_model)
|
41 |
florence_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
42 |
|
43 |
+
|
44 |
def apply_color_mask(frame, mask, obj_id):
|
45 |
cmap = plt.get_cmap("tab10")
|
46 |
color = np.array(cmap(obj_id % 10)[:3]) # Use modulo 10 to cycle through colors
|
requirements.txt
CHANGED
@@ -8,5 +8,4 @@ opencv-python
|
|
8 |
matplotlib
|
9 |
einops
|
10 |
timm
|
11 |
-
pytest
|
12 |
-
flash_attn
|
|
|
8 |
matplotlib
|
9 |
einops
|
10 |
timm
|
11 |
+
pytest
|
|
utils.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from unittest.mock import patch
|
3 |
+
from transformers.dynamic_module_utils import get_imports
|
4 |
+
|
5 |
+
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
|
6 |
+
"""Workaround for flash_attn import issue."""
|
7 |
+
if not str(filename).endswith(("modeling_phi.py", "configuration_florence2.py")):
|
8 |
+
return get_imports(filename)
|
9 |
+
imports = get_imports(filename)
|
10 |
+
if "flash_attn" in imports:
|
11 |
+
imports.remove("flash_attn")
|
12 |
+
return imports
|
13 |
+
|
14 |
+
def load_model_without_flash_attn(model_loader):
|
15 |
+
"""Load a model using the flash_attn workaround."""
|
16 |
+
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
17 |
+
return model_loader()
|