VikramSingh178 committed
Commit
88a381f
1 Parent(s): cca63d4

feat: Add YOLOv8s object detection model


Former-commit-id: 7181c5ce6b26b943a24ccc366026d4a88461c241

scripts/__pycache__/config.cpython-310.pyc CHANGED
Binary files a/scripts/__pycache__/config.cpython-310.pyc and b/scripts/__pycache__/config.cpython-310.pyc differ
 
scripts/config.py CHANGED
@@ -7,6 +7,7 @@ PROJECT_NAME = "Product Photography"
  PRODUCTS_10k_DATASET = "VikramSingh178/Products-10k-BLIP-captions"
  CAPTIONING_MODEL_NAME = "Salesforce/blip-image-captioning-base"
  SEGMENTATION_MODEL_NAME = "facebook/sam-vit-huge"
+ DETECTION_MODEL_NAME = "yolov8s"
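Note: scripts/utils.py below instantiates the detector via YOLO(DETECTION_MODEL_NAME). With a recent ultralytics release, the bare stem "yolov8s" should resolve to the pretrained yolov8s.pt checkpoint, downloaded automatically on first use. A minimal sketch of that call (the image path is a placeholder, not from this repo):

from ultralytics import YOLO

detector = YOLO("yolov8s")           # should resolve to yolov8s.pt (auto-downloaded)
results = detector("example1.jpg")   # placeholder path; paths, PIL images, or arrays are accepted
print(results[0].boxes.xyxy.tolist())  # [[x_min, y_min, x_max, y_max], ...]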
scripts/extended_image.png ADDED
scripts/mask.png ADDED
scripts/utils.py CHANGED
@@ -2,10 +2,8 @@ import torch
  from ultralytics import YOLO
  from transformers import SamModel, SamProcessor
  import numpy as np
- from PIL import Image
- from config import SEGMENTATION_MODEL_NAME
- import cv2
- import matplotlib.pyplot as plt
+ from PIL import Image, ImageOps
+ from config import SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME

  def accelerator():
      """
@@ -21,7 +19,6 @@ def accelerator():
      else:
          return "cpu"

-
  class ImageAugmentation:
      """
      Class for centering an image on a white background using ROI.
@@ -32,119 +29,63 @@ class ImageAugmentation:
          roi_scale (float): Scale factor to determine the size of the region of interest (ROI) in the original image.
      """

-     def __init__(self, target_width, target_height, roi_scale=0.5):
-         """
-         Initialize ImageAugmentation class.
-
-         Args:
-             target_width (int): Desired width of the extended image.
-             target_height (int): Desired height of the extended image.
-             roi_scale (float): Scale factor to determine the size of the region of interest (ROI) in the original image.
-         """
+     def __init__(self, target_width, target_height, roi_scale=0.6):
          self.target_width = target_width
          self.target_height = target_height
          self.roi_scale = roi_scale

-     def extend_image(self, image_path):
-         """
-         Extends the given image to the specified target dimensions while maintaining the aspect ratio of the original image.
-         The image is centered based on the detected region of interest (ROI).
-
-         Args:
-             image_path (str): The path to the image file.
-
-         Returns:
-             PIL.Image.Image: The extended image with the specified dimensions.
-         """
-         # Open the original image
-         original_image = cv2.imread(image_path)
-
-         # Convert the image to grayscale for better edge detection
-         gray_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
-
-         # Perform edge detection to find contours
-         edges = cv2.Canny(gray_image, 50, 150)
-         contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-
-         # Find the largest contour (assumed to be the ROI)
-         largest_contour = max(contours, key=cv2.contourArea)
-
-         # Get the bounding box of the largest contour
-         x, y, w, h = cv2.boundingRect(largest_contour)
-
-         # Calculate the center of the bounding box
-         roi_center_x = x + w // 2
-         roi_center_y = y + h // 2
-
-         # Calculate the top-left coordinates of the ROI
-         roi_x = max(0, roi_center_x - self.target_width // 2)
-         roi_y = max(0, roi_center_y - self.target_height // 2)
-
-         # Crop the ROI from the original image
-         roi = original_image[roi_y:roi_y+self.target_height, roi_x:roi_x+self.target_width]
-
-         # Create a new white background image with the target dimensions
-         extended_image = np.ones((self.target_height, self.target_width, 3), dtype=np.uint8) * 255
-
-         # Calculate the paste position for centering the ROI
-         paste_x = (self.target_width - roi.shape[1]) // 2
-         paste_y = (self.target_height - roi.shape[0]) // 2
-
-         # Paste the ROI onto the white background
-         extended_image[paste_y:paste_y+roi.shape[0], paste_x:paste_x+roi.shape[1]] = roi
-
-         return Image.fromarray(cv2.cvtColor(extended_image, cv2.COLOR_BGR2RGB))
-
-     def generate_bbox(self, image):
-         """
-         Generate bounding box for the input image.
-
-         Args:
-             image: The input image.
-
-         Returns:
-             list: Bounding box coordinates [x_min, y_min, x_max, y_max].
-         """
-         model = YOLO("yolov8s.pt")
-         results = model(image)
-         bbox = results[0].boxes.xyxy.tolist()
-         return bbox
-
-     def generate_mask(self, image, bbox):
-         """
-         Generates masks for the given image using a segmentation model.
-
-         Args:
-             image: The input image for which masks need to be generated.
-             bbox: Bounding box coordinates [x_min, y_min, x_max, y_max].
-
-         Returns:
-             numpy.ndarray: The generated mask.
-         """
-         model = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME).to(device=accelerator())
-         processor = SamProcessor.from_pretrained(SEGMENTATION_MODEL_NAME)
-
-         # Ensure bbox is in the correct format
-         bbox_list = [bbox]  # Convert bbox to list of lists
-
-         # Pass bbox as a list of lists to SamProcessor
-         inputs = processor(image, input_boxes=bbox_list, return_tensors="pt").to(device=accelerator())
-         with torch.no_grad():
-             outputs = model(**inputs)
-         masks = processor.image_processor.post_process_masks(
-             outputs.pred_masks,
-             inputs["original_sizes"],
-             inputs["reshaped_input_sizes"],
-         )
-
-         return masks[0].cpu().numpy()
+     def extend_image(self, image: Image) -> Image:
+         """
+         Extends an image to fit within the specified target dimensions while maintaining the aspect ratio.
+         """
+         original_width, original_height = image.size
+         scale = min(self.target_width / original_width, self.target_height / original_height)
+         new_width = int(original_width * scale * self.roi_scale)
+         new_height = int(original_height * scale * self.roi_scale)
+         resized_image = image.resize((new_width, new_height))
+         extended_image = Image.new("RGB", (self.target_width, self.target_height), "white")
+         paste_x = (self.target_width - new_width) // 2
+         paste_y = (self.target_height - new_height) // 2
+         extended_image.paste(resized_image, (paste_x, paste_y))
+         return extended_image
+
+     def generate_mask_from_bbox(self, image: Image) -> np.ndarray:
+         """
+         Generates a mask from the bounding box of an image using YOLO and SAM-ViT models.
+         """
+         yolo = YOLO(DETECTION_MODEL_NAME)
+         processor = SamProcessor.from_pretrained(SEGMENTATION_MODEL_NAME)
+         model = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME).to(accelerator())
+
+         # Run YOLO detection
+         results = yolo(np.array(image))
+         bboxes = results[0].boxes.xyxy.tolist()
+         print(bboxes)
+
+         # Prepare inputs for SAM
+         inputs = processor(image, input_boxes=[bboxes], return_tensors="pt").to(device=accelerator())
+         with torch.no_grad():
+             outputs = model(**inputs)
+         masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
+
+         return masks[0].numpy()
+
+     def invert_mask(self, mask_image: np.ndarray) -> np.ndarray:
+         """
+         Inverts the given mask image.
+         """
+         mask_image = (mask_image * 255).astype(np.uint8)
+         mask_pil = Image.fromarray(mask_image)
+         inverted_mask_pil = ImageOps.invert(mask_pil.convert("L"))
+         return inverted_mask_pil

  if __name__ == "__main__":
-     augmenter = ImageAugmentation(target_width=1920, target_height=1080, roi_scale=0.3)
+     augmenter = ImageAugmentation(target_width=1920, target_height=1080, roi_scale=0.6)
      image_path = "/home/product_diffusion_api/sample_data/example1.jpg"
-     extended_image = augmenter.extend_image(image_path)
-     bbox = augmenter.generate_bbox(extended_image)
-     mask = augmenter.generate_mask(extended_image, bbox)
-     plt.imsave('mask.jpg', mask)
-     #Image.fromarray(mask).save("centered_image_with_mask.jpg")
+     image = Image.open(image_path)
+     extended_image = augmenter.extend_image(image)
+     mask = augmenter.generate_mask_from_bbox(extended_image)
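For completeness, a usage sketch of the new pipeline, run from the scripts/ directory. It assumes generate_mask_from_bbox returns a boolean array of shape (num_boxes, num_masks, H, W) — the per-image output of SAM's post_process_masks — so a single 2D mask is picked out before saving or inverting. The input path and output file names are illustrative, chosen to echo the extended_image.png and mask.png artifacts added in this commit.

import numpy as np
from PIL import Image
from utils import ImageAugmentation

augmenter = ImageAugmentation(target_width=1920, target_height=1080, roi_scale=0.6)
image = Image.open("../sample_data/example1.jpg")  # placeholder path

extended_image = augmenter.extend_image(image)
extended_image.save("extended_image.png")

masks = augmenter.generate_mask_from_bbox(extended_image)
# Assumption: take the first mask of the first detected box; merging all
# masks instead (e.g. masks.any(axis=(0, 1))) would also be reasonable.
mask_2d = masks[0][0].astype(np.uint8)
Image.fromarray(mask_2d * 255).save("mask.png")

# invert_mask returns a PIL image (despite its ndarray annotation).
inverted = augmenter.invert_mask(mask_2d)
inverted.save("inverted_mask.png")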