"""Image augmentation utilities: centering images on a white canvas,
generating object masks with YOLO + SAM, and encoding/uploading results."""

import base64
import gc
import io
import uuid

import numpy as np
import torch
from diffusers.utils import load_image
from PIL import Image, ImageOps
from transformers import SamModel, SamProcessor
from ultralytics import YOLO

from config import SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME
from s3_manager import S3ManagerService
|
|
def clear_memory():
    """
    Release Python-level garbage and empty the CUDA memory cache.

    Useful between memory-intensive operations (e.g. loading and
    discarding large PyTorch models) to reduce peak GPU memory usage.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
|
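# Usage sketch (hypothetical flow): drop references to large objects first,
# so gc.collect() can reclaim them before the CUDA cache is emptied:
#
#     model = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME)
#     ...  # run inference
#     del model
#     clear_memory()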
|
|
def accelerator():
    """
    Determines the device accelerator to use based on availability.

    Returns:
        str: The name of the device accelerator ("cuda", "mps", or "cpu").
    """
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"
|
|
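# Usage sketch: pick the device once, then move models and inputs onto it:
#
#     device = accelerator()
#     model = SamModel.from_pretrained(SEGMENTATION_MODEL_NAME).to(device)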
|
class ImageAugmentation:
    """
    Centers an image on a white background using a region of interest (ROI).

    Attributes:
        target_width (int): Desired width of the extended image.
        target_height (int): Desired height of the extended image.
        roi_scale (float): Fraction of the aspect-preserving fit that the
            original image should occupy on the extended canvas.
    """

    def __init__(self, target_width: int, target_height: int, roi_scale: float = 0.6):
        self.target_width = target_width
        self.target_height = target_height
        self.roi_scale = roi_scale
|
|
|
    def extend_image(self, image: Image.Image) -> Image.Image:
        """
        Scales the image to fit the target dimensions while maintaining its
        aspect ratio, shrinks it by roi_scale, and centers it on a white canvas.
        """
        original_width, original_height = image.size
        # Aspect-preserving fit, then shrink to the desired ROI fraction.
        scale = min(self.target_width / original_width, self.target_height / original_height)
        new_width = int(original_width * scale * self.roi_scale)
        new_height = int(original_height * scale * self.roi_scale)
        resized_image = image.resize((new_width, new_height))
        # Paste the resized image at the center of the white canvas.
        extended_image = Image.new("RGB", (self.target_width, self.target_height), "white")
        paste_x = (self.target_width - new_width) // 2
        paste_y = (self.target_height - new_height) // 2
        extended_image.paste(resized_image, (paste_x, paste_y))
        return extended_image
|
|
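    # Worked example: with target 1024x1024 and roi_scale 0.5, an 800x600 input
    # fits at scale min(1024/800, 1024/600) = 1.28, shrinks to 512x384, and is
    # pasted at ((1024 - 512) // 2, (1024 - 384) // 2) = (256, 320).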
|
    def generate_mask_from_bbox(self, image: Image.Image, segmentation_model: str, detection_model: str) -> Image.Image:
        """
        Generates a mask for the most prominent object by detecting a bounding
        box with YOLO and segmenting inside it with a SAM model.

        Args:
            image (PIL.Image.Image): The input image.
            segmentation_model (str): Hugging Face model ID for SAM.
            detection_model (str): Path or name of the YOLO weights.

        Returns:
            PIL.Image.Image: The generated mask.
        """
        device = accelerator()
        yolo = YOLO(detection_model)
        processor = SamProcessor.from_pretrained(segmentation_model)
        model = SamModel.from_pretrained(segmentation_model).to(device)
        results = yolo(image)
        bboxes = results[0].boxes.xyxy.tolist()
        if not bboxes:
            raise ValueError("No objects detected; cannot build a SAM prompt box.")
        # SamProcessor expects input_boxes nested as (batch, boxes per image, 4);
        # only the first detected box is used here.
        input_boxes = [[bboxes[0]]]
        inputs = processor(load_image(image), input_boxes=input_boxes, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        mask = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )[0][0][0].numpy()
        # post_process_masks yields boolean masks; convert to 8-bit for PIL.
        mask_image = Image.fromarray((mask * 255).astype(np.uint8))
        return mask_image
|
|
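    # Note: the detector and SAM weights are reloaded on every call above. If
    # this method is called repeatedly, one alternative (a sketch, with
    # hypothetical attribute names) is to load them once in __init__ and reuse:
    #
    #     self.yolo = YOLO(detection_model)
    #     self.processor = SamProcessor.from_pretrained(segmentation_model)
    #     self.sam = SamModel.from_pretrained(segmentation_model).to(accelerator())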
|
|
    def invert_mask(self, mask_image: Image.Image) -> Image.Image:
        """
        Inverts the given mask image (foreground becomes background).
        """
        inverted_mask_pil = ImageOps.invert(mask_image.convert("L"))
        return inverted_mask_pil
|
|
|
def pil_to_b64_json(image: Image.Image) -> dict:
    """
    Converts a PIL image to a base64-encoded, JSON-serializable object.

    Args:
        image (PIL.Image.Image): The PIL image object to be converted.

    Returns:
        dict: A dictionary containing the image ID and the base64-encoded image.
    """
    image_id = str(uuid.uuid4())
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    b64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return {"image_id": image_id, "b64_image": b64_image}
|
|
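# Round-trip sketch: a client can recover the image from the payload with
#
#     payload = pil_to_b64_json(image)
#     decoded = Image.open(io.BytesIO(base64.b64decode(payload["b64_image"])))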
|
|
|
def pil_to_s3_json(image: Image.Image, file_name: str) -> dict:
    """
    Uploads a PIL image to Amazon S3 and returns a JSON-serializable object
    containing the image ID and a signed URL.

    Args:
        image (PIL.Image.Image): The PIL image to be uploaded.
        file_name (str): The base name for the uploaded file.

    Returns:
        dict: A dictionary containing the image ID and the signed URL.
    """
    image_id = str(uuid.uuid4())
    s3_uploader = S3ManagerService()
    image_bytes = io.BytesIO()
    image.save(image_bytes, format="PNG")
    image_bytes.seek(0)

    unique_file_name = s3_uploader.generate_unique_file_name(file_name)
    s3_uploader.upload_file(image_bytes, unique_file_name)
    # exp is passed through to S3ManagerService; 43200 is presumably seconds (12 hours).
    signed_url = s3_uploader.generate_signed_url(unique_file_name, exp=43200)
    return {"image_id": image_id, "url": signed_url}
|
|
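# Example payload shape (illustrative values only):
#
#     {"image_id": "9b1deb4d-...", "url": "https://<bucket>.s3.amazonaws.com/...?X-Amz-..."}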
|
|
if __name__ == "__main__": |
|
augmenter = ImageAugmentation(target_width=1024, target_height=1024, roi_scale=0.5) |
|
image_path = "../sample_data/example3.jpg" |
|
image = Image.open(image_path) |
|
extended_image = augmenter.extend_image(image) |
|
mask = augmenter.generate_mask_from_bbox(extended_image, SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME) |
|
inverted_mask_image = augmenter.invert_mask(mask) |
|
mask.save("mask.jpg") |
|
inverted_mask_image.save("inverted_mask.jpg") |
|
|