Spaces:
Runtime error
Runtime error
VikramSingh178
committed on
Commit
•
d1a4430
1
Parent(s):
88a381f
chore: Update segmentation model to facebook/sam-vit-large
Browse files
- scripts/config.py +1 -1
- scripts/utils.py +34 -23
scripts/config.py
CHANGED
@@ -6,7 +6,7 @@ DATASET_NAME= "hahminlew/kream-product-blip-captions"
|
|
6 |
PROJECT_NAME = "Product Photography"
|
7 |
PRODUCTS_10k_DATASET = "VikramSingh178/Products-10k-BLIP-captions"
|
8 |
CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
|
9 |
-
SEGMENTATION_MODEL_NAME = "facebook/sam-vit-
|
10 |
DETECTION_MODEL_NAME = "yolov8s"
|
11 |
|
12 |
|
|
|
6 |
PROJECT_NAME = "Product Photography"
|
7 |
PRODUCTS_10k_DATASET = "VikramSingh178/Products-10k-BLIP-captions"
|
8 |
CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
|
9 |
+
SEGMENTATION_MODEL_NAME = "facebook/sam-vit-large"
|
10 |
DETECTION_MODEL_NAME = "yolov8s"
|
11 |
|
12 |
|
scripts/utils.py
CHANGED
@@ -4,6 +4,9 @@ from transformers import SamModel, SamProcessor
|
|
4 |
import numpy as np
|
5 |
from PIL import Image, ImageOps
|
6 |
from config import SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME
|
|
|
|
|
|
|
7 |
|
8 |
def accelerator():
|
9 |
"""
|
@@ -49,43 +52,51 @@ class ImageAugmentation:
|
|
49 |
extended_image.paste(resized_image, (paste_x, paste_y))
|
50 |
return extended_image
|
51 |
|
52 |
-
def generate_mask_from_bbox(self,
|
53 |
"""
|
54 |
Generates a mask from the bounding box of an image using YOLO and SAM-ViT models.
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
"""
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
results = yolo(np.array(image))
|
62 |
bboxes = results[0].boxes.xyxy.tolist()
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
# Prepare inputs for SAM
|
67 |
-
inputs = processor(image, input_boxes=[bboxes], return_tensors="pt").to(device=accelerator())
|
68 |
with torch.no_grad():
|
69 |
outputs = model(**inputs)
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
def invert_mask(self, mask_image: np.ndarray) -> np.ndarray:
|
76 |
"""
|
77 |
Inverts the given mask image.
|
78 |
"""
|
79 |
-
mask_image = (mask_image * 255).astype(np.uint8)
|
80 |
-
mask_pil = Image.fromarray(mask_image)
|
81 |
|
82 |
-
|
|
|
83 |
return inverted_mask_pil
|
84 |
|
85 |
if __name__ == "__main__":
|
86 |
-
augmenter = ImageAugmentation(target_width=
|
87 |
-
image_path = "/home/product_diffusion_api/sample_data/
|
88 |
image = Image.open(image_path)
|
89 |
extended_image = augmenter.extend_image(image)
|
90 |
-
mask = augmenter.generate_mask_from_bbox(extended_image)
|
91 |
-
|
|
|
|
|
|
4 |
import numpy as np
|
5 |
from PIL import Image, ImageOps
|
6 |
from config import SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME
|
7 |
+
from diffusers.utils import load_image
|
8 |
+
|
9 |
+
|
10 |
|
11 |
def accelerator():
|
12 |
"""
|
|
|
52 |
extended_image.paste(resized_image, (paste_x, paste_y))
|
53 |
return extended_image
|
54 |
|
55 |
+
def generate_mask_from_bbox(self,image: Image, segmentation_model: str ,detection_model) -> Image:
|
56 |
"""
|
57 |
Generates a mask from the bounding box of an image using YOLO and SAM-ViT models.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
image_path (str): The path to the input image.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
numpy.ndarray: The generated mask as a NumPy array.
|
64 |
"""
|
65 |
+
|
66 |
+
yolo = YOLO(detection_model)
|
67 |
+
processor = SamProcessor.from_pretrained(segmentation_model)
|
68 |
+
model = SamModel.from_pretrained(segmentation_model).to(device=accelerator())
|
69 |
+
results = yolo(image)
|
|
|
70 |
bboxes = results[0].boxes.xyxy.tolist()
|
71 |
+
input_boxes = [[[bboxes[0]]]]
|
72 |
+
inputs = processor(load_image(image), input_boxes=input_boxes, return_tensors="pt").to("cuda")
|
|
|
|
|
|
|
73 |
with torch.no_grad():
|
74 |
outputs = model(**inputs)
|
75 |
+
mask = processor.image_processor.post_process_masks(
|
76 |
+
outputs.pred_masks.cpu(),
|
77 |
+
inputs["original_sizes"].cpu(),
|
78 |
+
inputs["reshaped_input_sizes"].cpu()
|
79 |
+
)[0][0][0].numpy()
|
80 |
+
mask_image = Image.fromarray(mask)
|
81 |
+
return mask_image
|
82 |
+
|
83 |
+
|
84 |
|
85 |
def invert_mask(self, mask_image: np.ndarray) -> np.ndarray:
|
86 |
"""
|
87 |
Inverts the given mask image.
|
88 |
"""
|
|
|
|
|
89 |
|
90 |
+
|
91 |
+
inverted_mask_pil = ImageOps.invert(mask_image.convert("L"))
|
92 |
return inverted_mask_pil
|
93 |
|
94 |
if __name__ == "__main__":
|
95 |
+
augmenter = ImageAugmentation(target_width=2560, target_height=1440, roi_scale=0.7)
|
96 |
+
image_path = "/home/product_diffusion_api/sample_data/example3.jpg"
|
97 |
image = Image.open(image_path)
|
98 |
extended_image = augmenter.extend_image(image)
|
99 |
+
mask = augmenter.generate_mask_from_bbox(extended_image, SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME)
|
100 |
+
inverted_mask_image = augmenter.invert_mask(mask)
|
101 |
+
mask.save("mask.jpg")
|
102 |
+
inverted_mask_image.save("inverted_mask.jpg")
|