File size: 5,643 Bytes
f3cfe0c
 
88e9206
f3cfe0c
88a381f
6e67c16
d1a4430
fa32203
6e67c16
4b8ee81
 
 
 
fa32203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1a4430
 
f3cfe0c
 
 
 
 
 
 
 
 
cca63d4
f3cfe0c
cca63d4
f3cfe0c
cca63d4
f3cfe0c
a76141d
 
 
 
 
cca63d4
 
 
a76141d
f3cfe0c
88a381f
cca63d4
 
 
a76141d
88a381f
a76141d
88a381f
a76141d
88a381f
 
 
 
 
 
 
 
 
 
 
d1a4430
88a381f
 
d1a4430
 
 
 
 
 
88a381f
d1a4430
 
 
 
 
88a381f
d1a4430
 
88a381f
 
d1a4430
 
 
 
 
 
 
 
 
a76141d
88a381f
cca63d4
88a381f
cca63d4
 
d1a4430
 
88a381f
4b8ee81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a76141d
 
5626570
fa32203
88a381f
 
d1a4430
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import torch
from ultralytics import YOLO
from transformers import SamModel, SamProcessor
import numpy as np
from PIL import Image, ImageOps
from scripts.config import SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME
from diffusers.utils import load_image
import gc
from scripts.s3_manager import S3ManagerService
import io
from io import BytesIO
import base64
import uuid






def clear_memory():
    """
    Free unreferenced Python objects and release cached accelerator memory.

    Runs the garbage collector, then asks PyTorch to hand back cached-but-unused
    allocator memory for CUDA and, when it is available, Apple MPS -- the same
    backends that ``accelerator()`` may select. Useful between memory-intensive
    model invocations.
    """
    gc.collect()
    torch.cuda.empty_cache()
    # accelerator() can pick "mps", so release that backend's cache too.
    # getattr/hasattr guards keep this safe on torch builds without MPS support.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_module = getattr(torch, "mps", None)
    if (
        mps_backend is not None
        and mps_backend.is_available()
        and hasattr(mps_module, "empty_cache")
    ):
        mps_module.empty_cache()
   




def accelerator():
    """
    Pick the torch device string for the best available backend.

    Preference order is CUDA, then Apple MPS, with CPU as the fallback.

    Returns:
        str: One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    # Probes are lambdas so a backend is only queried when every
    # higher-priority backend has already been found unavailable.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    return next((device for device, is_available in probes if is_available()), "cpu")

class ImageAugmentation:
    """
    Center an image on a white background and derive object masks from it.

    Attributes:
        target_width (int): Width of the extended (output) canvas.
        target_height (int): Height of the extended (output) canvas.
        roi_scale (float): Extra scale applied after fitting the image inside the
            target canvas; e.g. 0.6 makes the pasted image 60% of the largest
            aspect-preserving size that would fit.
    """

    def __init__(self, target_width: int, target_height: int, roi_scale: float = 0.6):
        self.target_width = target_width
        self.target_height = target_height
        self.roi_scale = roi_scale

    def extend_image(self, image: Image.Image) -> Image.Image:
        """
        Resize ``image`` (keeping its aspect ratio) and center it on a white RGB canvas.

        The image is scaled to the largest size that fits inside
        ``target_width`` x ``target_height``, shrunk further by ``roi_scale``,
        and pasted in the middle of a white canvas of the target size.
        Transparent regions of images with an alpha channel (or a palette
        transparency entry) show the white background.

        Args:
            image (PIL.Image.Image): Source image.

        Returns:
            PIL.Image.Image: RGB image of size (target_width, target_height).
        """
        original_width, original_height = image.size
        scale = min(self.target_width / original_width, self.target_height / original_height)
        # Clamp to at least 1px: PIL refuses to resize to a zero-sized image,
        # which would otherwise happen for tiny inputs or small roi_scale values.
        new_width = max(1, int(original_width * scale * self.roi_scale))
        new_height = max(1, int(original_height * scale * self.roi_scale))
        resized_image = image.resize((new_width, new_height))
        extended_image = Image.new("RGB", (self.target_width, self.target_height), "white")
        paste_x = (self.target_width - new_width) // 2
        paste_y = (self.target_height - new_height) // 2
        has_alpha = "A" in resized_image.getbands() or (
            resized_image.mode == "P" and "transparency" in resized_image.info
        )
        if has_alpha:
            # Without a mask, transparent pixels would be pasted with their stored
            # RGB values (often black); using the alpha layer keeps them white.
            rgba_image = resized_image.convert("RGBA")
            extended_image.paste(rgba_image, (paste_x, paste_y), rgba_image)
        else:
            extended_image.paste(resized_image, (paste_x, paste_y))
        return extended_image

    def generate_mask_from_bbox(
        self, image: Image.Image, segmentation_model: str, detection_model: str
    ) -> Image.Image:
        """
        Segment the first object YOLO detects in ``image``, using SAM prompted with its box.

        Args:
            image (PIL.Image.Image): Input image. It is passed to YOLO as-is and to
                SAM through ``diffusers.utils.load_image``.
            segmentation_model (str): Identifier passed to
                ``SamProcessor.from_pretrained`` and ``SamModel.from_pretrained``.
            detection_model (str): Weights identifier passed to ``ultralytics.YOLO``.

        Returns:
            PIL.Image.Image: Mask image built from the selected SAM mask.

        Raises:
            ValueError: If YOLO detects no object in ``image``.
        """
        # Use one device for both the model and its inputs; previously inputs were
        # always sent to "cuda", which broke on MPS/CPU hosts.
        device = accelerator()

        yolo = YOLO(detection_model)
        results = yolo(image)
        bboxes = results[0].boxes.xyxy.tolist()
        if not bboxes:
            # Fail before loading SAM, with a clearer error than IndexError.
            raise ValueError("No object detected in the image; cannot prompt SAM with a bounding box.")

        processor = SamProcessor.from_pretrained(segmentation_model)
        model = SamModel.from_pretrained(segmentation_model).to(device=device)

        # Prompt SAM with the first box YOLO returned (xyxy pixel coordinates).
        input_boxes = [[[bboxes[0]]]]
        inputs = processor(load_image(image), input_boxes=input_boxes, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Upscale predicted masks back to the original image size, then take the
        # first image / first box / first mask candidate.
        # NOTE(review): SAM can return several candidates per box; picking by
        # outputs.iou_scores may be preferable -- confirm the intended choice.
        mask = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0][0].numpy()
        return Image.fromarray(mask)

    def invert_mask(self, mask_image: "Image.Image | np.ndarray") -> Image.Image:
        """
        Return a grayscale ("L") inverted copy of a mask.

        Args:
            mask_image: Mask as a PIL image, or a NumPy array that
                ``Image.fromarray`` can convert.

        Returns:
            PIL.Image.Image: The mask converted to mode "L" and inverted.
        """
        if isinstance(mask_image, np.ndarray):
            mask_image = Image.fromarray(mask_image)
        return ImageOps.invert(mask_image.convert("L"))
    
def pil_to_b64_json(image):
    """
    Encode an image as PNG and wrap its base64 text in a JSON-ready dict.

    Args:
        image (PIL.Image.Image): Image to serialize.

    Returns:
        dict: ``{"image_id": <random UUID4 string>, "b64_image": <base64 PNG text>}``.
    """
    identifier = str(uuid.uuid4())
    with BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return {"image_id": identifier, "b64_image": encoded}


def pil_to_s3_json(image: Image.Image, file_name) -> dict:
    """
    Upload an image to S3 as PNG and return its id together with a signed URL.

    Args:
        image (PIL.Image.Image): Image to upload.
        file_name (str): Base name handed to ``generate_unique_file_name``
            to build the object key.

    Returns:
        dict: ``{"image_id": <random UUID4 string>, "url": <signed URL>}``.
    """
    signed_url_lifetime = 43200  # 12 hours
    record_id = str(uuid.uuid4())
    uploader = S3ManagerService()

    payload = io.BytesIO()
    image.save(payload, format="PNG")
    payload.seek(0)  # rewind so the upload reads the PNG from its first byte

    object_key = uploader.generate_unique_file_name(file_name)
    uploader.upload_file(payload, object_key)
    url = uploader.generate_signed_url(object_key, exp=signed_url_lifetime)
    return {"image_id": record_id, "url": url}




if __name__ == "__main__":
    # Demo: center a sample image on a 1024x1024 white canvas, segment the
    # detected object, and save the mask and its inverse.
    augmenter = ImageAugmentation(target_width=1024, target_height=1024, roi_scale=0.5)
    # NOTE(review): this relative path resolves against the current working
    # directory, not this file's location -- confirm where the demo is run from.
    image_path = "../sample_data/example3.jpg"
    # Context manager closes the file handle once the resized copy is built.
    with Image.open(image_path) as image:
        extended_image = augmenter.extend_image(image)
    mask = augmenter.generate_mask_from_bbox(extended_image, SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME)
    inverted_mask_image = augmenter.invert_mask(mask)
    mask.save("mask.jpg")
    inverted_mask_image.save("inverted_mask.jpg")