Spaces:
Runtime error
Runtime error
VikramSingh178
committed on
Commit
•
88e9206
1
Parent(s):
a76141d
refactor: Update import statement for accelerator and image augmentation functionality
Browse files- scripts/__pycache__/config.cpython-312.pyc +0 -0
- scripts/utils.py +43 -19
scripts/__pycache__/config.cpython-312.pyc
ADDED
Binary file (3.22 kB). View file
|
|
scripts/utils.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
import torch
|
2 |
from ultralytics import YOLO
|
3 |
-
from transformers import SamModel,SamProcessor
|
4 |
import numpy as np
|
5 |
-
from PIL import Image
|
6 |
from config import SEGMENTATION_MODEL_NAME
|
7 |
|
8 |
|
@@ -14,15 +14,14 @@ def accelerator():
|
|
14 |
str: The name of the device accelerator ('cuda', 'mps', or 'cpu').
|
15 |
"""
|
16 |
if torch.cuda.is_available():
|
17 |
-
device =
|
18 |
elif torch.backends.mps.is_available():
|
19 |
-
device =
|
20 |
else:
|
21 |
-
device =
|
22 |
return device
|
23 |
|
24 |
|
25 |
-
|
26 |
class ImageAugmentation:
|
27 |
"""
|
28 |
Class for centering an image on a white background using ROI.
|
@@ -54,7 +53,10 @@ class ImageAugmentation:
|
|
54 |
w, h = self.background_size
|
55 |
bg = np.ones((h, w, 3), dtype=np.uint8) * 255 # White background
|
56 |
x, y, roi_w, roi_h = roi
|
57 |
-
bg[
|
|
|
|
|
|
|
58 |
return bg
|
59 |
|
60 |
def detect_region_of_interest(self, image):
|
@@ -69,11 +71,12 @@ class ImageAugmentation:
|
|
69 |
"""
|
70 |
# Convert image to grayscale
|
71 |
grayscale_image = np.array(Image.fromarray(image).convert("L"))
|
72 |
-
|
73 |
# Calculate bounding box of non-zero region
|
74 |
bbox = Image.fromarray(grayscale_image).getbbox()
|
75 |
return bbox
|
76 |
|
|
|
77 |
def generate_bbox(image):
|
78 |
"""
|
79 |
Generate bounding box for the input image.
|
@@ -85,17 +88,39 @@ def generate_bbox(image):
|
|
85 |
list: Bounding boxes in (x1, y1, x2, y2) corner format, one entry per detection.
|
86 |
"""
|
87 |
# Load YOLOv8 model
|
88 |
-
model = YOLO("yolov8s.pt")
|
89 |
results = model(image)
|
90 |
# Get bounding box coordinates
|
91 |
bbox = results[0].boxes.xyxy.int().tolist()
|
92 |
return bbox
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
|
101 |
if __name__ == "__main__":
|
@@ -104,8 +129,7 @@ if __name__ == "__main__":
|
|
104 |
image = np.array(Image.open(image_path).convert("RGB"))
|
105 |
roi = augmenter.detect_region_of_interest(image)
|
106 |
centered_image = augmenter.center_image_on_background(image, roi)
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
1 |
import torch
|
2 |
from ultralytics import YOLO
|
3 |
+
from transformers import SamModel, SamProcessor
|
4 |
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
from config import SEGMENTATION_MODEL_NAME
|
7 |
|
8 |
|
|
|
14 |
str: The name of the device accelerator ('cuda', 'mps', or 'cpu').
|
15 |
"""
|
16 |
if torch.cuda.is_available():
|
17 |
+
device = "cuda"
|
18 |
elif torch.backends.mps.is_available():
|
19 |
+
device = "mps"
|
20 |
else:
|
21 |
+
device = "cpu"
|
22 |
return device
|
23 |
|
24 |
|
|
|
25 |
class ImageAugmentation:
|
26 |
"""
|
27 |
Class for centering an image on a white background using ROI.
|
|
|
53 |
w, h = self.background_size
|
54 |
bg = np.ones((h, w, 3), dtype=np.uint8) * 255 # White background
|
55 |
x, y, roi_w, roi_h = roi
|
56 |
+
bg[
|
57 |
+
(h - roi_h) // 2 : (h - roi_h) // 2 + roi_h,
|
58 |
+
(w - roi_w) // 2 : (w - roi_w) // 2 + roi_w,
|
59 |
+
] = image
|
60 |
return bg
|
61 |
|
62 |
def detect_region_of_interest(self, image):
|
|
|
71 |
"""
|
72 |
# Convert image to grayscale
|
73 |
grayscale_image = np.array(Image.fromarray(image).convert("L"))
|
74 |
+
|
75 |
# Calculate bounding box of non-zero region
|
76 |
bbox = Image.fromarray(grayscale_image).getbbox()
|
77 |
return bbox
|
78 |
|
79 |
+
|
80 |
def generate_bbox(image):
|
81 |
"""
|
82 |
Generate bounding box for the input image.
|
|
|
88 |
list: Bounding boxes in (x1, y1, x2, y2) corner format, one entry per detection.
|
89 |
"""
|
90 |
# Load YOLOv8 model
|
91 |
+
model = YOLO("../models/yolov8s.pt")
|
92 |
results = model(image)
|
93 |
# Get bounding box coordinates
|
94 |
bbox = results[0].boxes.xyxy.int().tolist()
|
95 |
return bbox
|
96 |
|
97 |
+
|
98 |
+
def generate_mask(image):
    """
    Generate segmentation masks for the given image.

    The boxes returned by ``generate_bbox(image)`` are passed to the
    segmentation model named by ``SEGMENTATION_MODEL_NAME`` as box prompts.

    Args:
        image: The input image for which masks need to be generated.

    Returns:
        The post-processed masks, as returned by
        ``processor.image_processor.post_process_masks`` on CPU copies of
        the model outputs.
    """
    # Resolve the device once so the model and its inputs use the same one.
    device = accelerator()
    model = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME).to(device=device)
    processor = SamProcessor.from_pretrained(SEGMENTATION_MODEL_NAME)

    # NOTE(review): the float cast presumably avoids float64 inputs on
    # backends that do not support them (e.g. MPS) -- confirm.
    inputs = processor(
        image, input_boxes=[generate_bbox(image)], return_tensors="pt"
    ).to(torch.float)
    inputs = inputs.to(device=device)

    # Inference only: disable autograd so no graph is built for the
    # forward pass.
    with torch.no_grad():
        outputs = model(**inputs)

    return processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
|
124 |
|
125 |
|
126 |
if __name__ == "__main__":
|
|
|
129 |
image = np.array(Image.open(image_path).convert("RGB"))
|
130 |
roi = augmenter.detect_region_of_interest(image)
|
131 |
centered_image = augmenter.center_image_on_background(image, roi)
|
132 |
+
masks = generate_mask(Image.fromarray(centered_image))
|
133 |
+
masks = np.array(masks)
|
134 |
+
mask_image = Image.fromarray(masks[0])
|
135 |
+
mask_image.save("mask.jpg")
|
|