VikramSingh178 committed
Commit: d1a4430
1 Parent(s): 88a381f

chore: Update segmentation model to facebook/sam-vit-large

Files changed (2):
  1. scripts/config.py +1 -1
  2. scripts/utils.py +34 -23
scripts/config.py CHANGED
@@ -6,7 +6,7 @@ DATASET_NAME= "hahminlew/kream-product-blip-captions"
  PROJECT_NAME = "Product Photography"
  PRODUCTS_10k_DATASET = "VikramSingh178/Products-10k-BLIP-captions"
  CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
- SEGMENTATION_MODEL_NAME = "facebook/sam-vit-huge"
+ SEGMENTATION_MODEL_NAME = "facebook/sam-vit-large"
  DETECTION_MODEL_NAME = "yolov8s"
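The config change is the point of the commit: sam-vit-large is a substantially smaller checkpoint than sam-vit-huge (roughly 0.3B vs. 0.6B parameters), trading a little mask quality for faster downloads and lower VRAM use. A minimal smoke test for the new constant, assuming a transformers install and Hub access (illustrative, not part of the commit):

# Illustrative smoke test: verify the new checkpoint name in
# scripts/config.py resolves and loads, the same way scripts/utils.py
# consumes it, and report the parameter count.
from transformers import SamModel, SamProcessor

from config import SEGMENTATION_MODEL_NAME  # now "facebook/sam-vit-large"

processor = SamProcessor.from_pretrained(SEGMENTATION_MODEL_NAME)
model = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")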
scripts/utils.py CHANGED
@@ -4,6 +4,9 @@ from transformers import SamModel, SamProcessor
  import numpy as np
  from PIL import Image, ImageOps
  from config import SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME
+ from diffusers.utils import load_image
+
+

  def accelerator():
      """
@@ -49,43 +52,51 @@ class ImageAugmentation:
          extended_image.paste(resized_image, (paste_x, paste_y))
          return extended_image

-     def generate_mask_from_bbox(self, image: Image) -> np.ndarray:
+     def generate_mask_from_bbox(self, image: Image, segmentation_model: str, detection_model: str) -> Image:
          """
          Generates a mask from the bounding box of an image using YOLO and SAM-ViT models.
+
+         Args:
+             image (Image): The input image.
+             segmentation_model (str): Name of the pretrained SAM checkpoint.
+             detection_model (str): Name of the YOLO detection model.
+
+         Returns:
+             Image: The generated mask as a PIL image.
          """
-         yolo = YOLO(DETECTION_MODEL_NAME)
-         processor = SamProcessor.from_pretrained(SEGMENTATION_MODEL_NAME)
-         model = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME).to(accelerator())
-
-         # Run YOLO detection
-         results = yolo(np.array(image))
+         yolo = YOLO(detection_model)
+         processor = SamProcessor.from_pretrained(segmentation_model)
+         model = SamModel.from_pretrained(segmentation_model).to(device=accelerator())
+         results = yolo(image)
          bboxes = results[0].boxes.xyxy.tolist()
-         print(bboxes)
-
-         # Prepare inputs for SAM
-         inputs = processor(image, input_boxes=[bboxes], return_tensors="pt").to(device=accelerator())
+         # Prompt SAM with the first detected box; nesting is [batch][boxes][x0, y0, x1, y1]
+         input_boxes = [[bboxes[0]]]
+         inputs = processor(load_image(image), input_boxes=input_boxes, return_tensors="pt").to(device=accelerator())
          with torch.no_grad():
              outputs = model(**inputs)
-         masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
-
-         return masks[0].numpy()
+         mask = processor.image_processor.post_process_masks(
+             outputs.pred_masks.cpu(),
+             inputs["original_sizes"].cpu(),
+             inputs["reshaped_input_sizes"].cpu(),
+         )[0][0][0].numpy()
+         # The mask comes back boolean; convert to 8-bit grayscale for PIL
+         mask_image = Image.fromarray((mask * 255).astype(np.uint8))
+         return mask_image

-     def invert_mask(self, mask_image: np.ndarray) -> np.ndarray:
+     def invert_mask(self, mask_image: Image) -> Image:
          """
          Inverts the given mask image.
          """
-         mask_image = (mask_image * 255).astype(np.uint8)
-         mask_pil = Image.fromarray(mask_image)
-         inverted_mask_pil = ImageOps.invert(mask_pil.convert("L"))
+         inverted_mask_pil = ImageOps.invert(mask_image.convert("L"))
          return inverted_mask_pil

  if __name__ == "__main__":
-     augmenter = ImageAugmentation(target_width=1920, target_height=1080, roi_scale=0.6)
-     image_path = "/home/product_diffusion_api/sample_data/example1.jpg"
+     augmenter = ImageAugmentation(target_width=2560, target_height=1440, roi_scale=0.7)
+     image_path = "/home/product_diffusion_api/sample_data/example3.jpg"
      image = Image.open(image_path)
      extended_image = augmenter.extend_image(image)
-     mask = augmenter.generate_mask_from_bbox(extended_image)
+     mask = augmenter.generate_mask_from_bbox(extended_image, SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME)
+     inverted_mask_image = augmenter.invert_mask(mask)
+     mask.save("mask.jpg")
+     inverted_mask_image.save("inverted_mask.jpg")