Spaces:
Running
on
Zero
Running
on
Zero
UPD: added setup.py for installation
Browse files- SegmentAnything2AssistApp.py +1 -1
- setup.py +25 -0
- src/{YOLOv10Plugin.py β SegmentAnything2Assist/Plugin/YOLOv10Plugin.py} +0 -0
- src/{__init__.py β SegmentAnything2Assist/Plugin/__init__.py} +0 -0
- src/{SegmentAnything2Assist.py β SegmentAnything2Assist/SegmentAnything2Assist.py} +11 -7
- src/SegmentAnything2Assist/__init__.py +0 -0
- test/assets/liberty.jpg +0 -0
- test/test_module.py +59 -0
SegmentAnything2AssistApp.py
CHANGED
@@ -4,7 +4,7 @@ import gradio_imageslider
|
|
4 |
import spaces
|
5 |
import torch
|
6 |
|
7 |
-
import src.SegmentAnything2Assist as SegmentAnything2Assist
|
8 |
|
9 |
example_image_annotation = {
|
10 |
"image": "assets/cars.jpg",
|
|
|
4 |
import spaces
|
5 |
import torch
|
6 |
|
7 |
+
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
|
8 |
|
9 |
example_image_annotation = {
|
10 |
"image": "assets/cars.jpg",
|
setup.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
setup(
|
4 |
+
name="SegmentAnything2Assist",
|
5 |
+
version="0.1",
|
6 |
+
packages=find_packages(where="src"),
|
7 |
+
package_dir={"": "src"},
|
8 |
+
install_requires=[
|
9 |
+
"SAM-2 @ git+https://github.com/facebookresearch/segment-anything-2.git@7e1596c0b6462eb1d1ba7e1492430fed95023598",
|
10 |
+
"ultralytics @ git+https://github.com/THU-MIG/yolov10.git@cd2f79c70299c9041fb6d19617ef1296f47575b1",
|
11 |
+
"opencv-python==4.10.0.84",
|
12 |
+
],
|
13 |
+
author="xqt",
|
14 |
+
author_email="xqt@users.noreply.huggingface.co",
|
15 |
+
description="A package to segment anything and assist in the process",
|
16 |
+
long_description=open("README.md").read(),
|
17 |
+
long_description_content_type="text/markdown",
|
18 |
+
url="https://huggingface.co/spaces/xqt/Segment-Anything-2-Assist",
|
19 |
+
classifiers=[
|
20 |
+
"Programming Language :: Python :: 3",
|
21 |
+
"License :: OSI Approved :: MIT License",
|
22 |
+
"Operating System :: OS Independent",
|
23 |
+
],
|
24 |
+
python_requires=">=3.8.0",
|
25 |
+
)
|
src/{YOLOv10Plugin.py β SegmentAnything2Assist/Plugin/YOLOv10Plugin.py}
RENAMED
File without changes
|
src/{__init__.py β SegmentAnything2Assist/Plugin/__init__.py}
RENAMED
File without changes
|
src/{SegmentAnything2Assist.py β SegmentAnything2Assist/SegmentAnything2Assist.py}
RENAMED
@@ -5,12 +5,11 @@ import tqdm
|
|
5 |
import requests
|
6 |
import torch
|
7 |
import numpy
|
8 |
-
import pickle
|
9 |
|
10 |
import sam2.build_sam
|
11 |
import sam2.automatic_mask_generator
|
12 |
|
13 |
-
from . import YOLOv10Plugin
|
14 |
|
15 |
import cv2
|
16 |
|
@@ -122,14 +121,17 @@ class SegmentAnything2Assist:
|
|
122 |
print(f"SegmentAnything2Assist::is_model_available::{ret}")
|
123 |
return ret
|
124 |
|
125 |
-
def load_model(self) ->
|
126 |
if self.is_model_available():
|
127 |
self.sam2 = sam2.build_sam(checkpoint=self.model_path)
|
|
|
128 |
|
129 |
-
|
|
|
|
|
130 |
if not force and self.is_model_available():
|
131 |
print(f"{self.model_path} already exists. Skipping download.")
|
132 |
-
return
|
133 |
|
134 |
response = requests.get(self.download_url, stream=True)
|
135 |
total_size = int(response.headers.get("content-length", 0))
|
@@ -141,10 +143,12 @@ class SegmentAnything2Assist:
|
|
141 |
file.write(data)
|
142 |
progress_bar.update(len(data))
|
143 |
|
|
|
|
|
144 |
def generate_automatic_masks(
|
145 |
self,
|
146 |
-
image,
|
147 |
-
points_per_side=
|
148 |
points_per_batch=32,
|
149 |
pred_iou_thresh=0.8,
|
150 |
stability_score_thresh=0.95,
|
|
|
5 |
import requests
|
6 |
import torch
|
7 |
import numpy
|
|
|
8 |
|
9 |
import sam2.build_sam
|
10 |
import sam2.automatic_mask_generator
|
11 |
|
12 |
+
from .Plugin import YOLOv10Plugin
|
13 |
|
14 |
import cv2
|
15 |
|
|
|
121 |
print(f"SegmentAnything2Assist::is_model_available::{ret}")
|
122 |
return ret
|
123 |
|
124 |
+
def load_model(self) -> bool:
|
125 |
if self.is_model_available():
|
126 |
self.sam2 = sam2.build_sam(checkpoint=self.model_path)
|
127 |
+
return True
|
128 |
|
129 |
+
return False
|
130 |
+
|
131 |
+
def download_model(self, force: bool = False) -> bool:
|
132 |
if not force and self.is_model_available():
|
133 |
print(f"{self.model_path} already exists. Skipping download.")
|
134 |
+
return False
|
135 |
|
136 |
response = requests.get(self.download_url, stream=True)
|
137 |
total_size = int(response.headers.get("content-length", 0))
|
|
|
143 |
file.write(data)
|
144 |
progress_bar.update(len(data))
|
145 |
|
146 |
+
return True
|
147 |
+
|
148 |
def generate_automatic_masks(
|
149 |
self,
|
150 |
+
image: numpy.ndarray,
|
151 |
+
points_per_side=10,
|
152 |
points_per_batch=32,
|
153 |
pred_iou_thresh=0.8,
|
154 |
stability_score_thresh=0.95,
|
src/SegmentAnything2Assist/__init__.py
ADDED
File without changes
|
test/assets/liberty.jpg
ADDED
test/test_module.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
class TestSegmentAnything2Assist(unittest.TestCase):
|
7 |
+
def setUp(self) -> None:
|
8 |
+
return super().setUp()
|
9 |
+
|
10 |
+
def tearDown(self) -> None:
|
11 |
+
return super().tearDown()
|
12 |
+
|
13 |
+
def _loading_all_sam_model_types(self):
|
14 |
+
# Test loading all types of SAM2 models.
|
15 |
+
all_sam_models_type = [
|
16 |
+
"sam2_hiera_tiny",
|
17 |
+
"sam2_hiera_small",
|
18 |
+
"sam2_hiera_base_plus",
|
19 |
+
"sam2_hiera_large",
|
20 |
+
]
|
21 |
+
for sam_model_type in all_sam_models_type:
|
22 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
23 |
+
sam_model_name=sam_model_type, download=True, device="cpu"
|
24 |
+
)
|
25 |
+
self.assertEqual(sam_model.is_model_available(), True)
|
26 |
+
|
27 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
28 |
+
sam_model_name=sam_model_type,
|
29 |
+
download=False,
|
30 |
+
model_path=f".tmp/checkpoints/{sam_model_type}.pth",
|
31 |
+
device="cpu",
|
32 |
+
)
|
33 |
+
|
34 |
+
with self.assertRaises(Exception):
|
35 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
36 |
+
sam_model_name=sam_model_type,
|
37 |
+
download=False,
|
38 |
+
model_path=".",
|
39 |
+
device="cpu",
|
40 |
+
)
|
41 |
+
|
42 |
+
def test_generate_automatic_mask(self):
|
43 |
+
image = cv2.imread("test/assets/liberty.jpg")
|
44 |
+
|
45 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
46 |
+
sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
|
47 |
+
)
|
48 |
+
|
49 |
+
masks, segmentation_masks, bboxes = sam_model.generate_automatic_masks(image)
|
50 |
+
|
51 |
+
print(type(masks[0]))
|
52 |
+
print(type(segmentation_masks[0]))
|
53 |
+
print(type(bboxes[0]))
|
54 |
+
|
55 |
+
self.assertEqual(len(masks), len(segmentation_masks))
|
56 |
+
self.assertEqual(len(masks), len(bboxes))
|
57 |
+
|
58 |
+
# for mask, segmentation_mask, bbox in zip(masks, segmentation_masks, bboxes):
|
59 |
+
self.assertEqual(segmentation_masks[0].shape, image.shape)
|