Spaces:
Running
Running
File size: 5,643 Bytes
f3cfe0c 88e9206 f3cfe0c 88a381f 6e67c16 d1a4430 fa32203 6e67c16 4b8ee81 fa32203 d1a4430 f3cfe0c cca63d4 f3cfe0c cca63d4 f3cfe0c cca63d4 f3cfe0c a76141d cca63d4 a76141d f3cfe0c 88a381f cca63d4 a76141d 88a381f a76141d 88a381f a76141d 88a381f d1a4430 88a381f d1a4430 88a381f d1a4430 88a381f d1a4430 88a381f d1a4430 a76141d 88a381f cca63d4 88a381f cca63d4 d1a4430 88a381f 4b8ee81 a76141d 5626570 fa32203 88a381f d1a4430 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import torch
from ultralytics import YOLO
from transformers import SamModel, SamProcessor
import numpy as np
from PIL import Image, ImageOps
from scripts.config import SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME
from diffusers.utils import load_image
import gc
from scripts.s3_manager import S3ManagerService
import io
from io import BytesIO
import base64
import uuid
def clear_memory():
    """
    Release cached memory held by Python and the CUDA allocator.

    Runs the garbage collector first so unreachable tensors are freed,
    then empties the CUDA cache. The cache flush is guarded by
    ``torch.cuda.is_available()`` so the helper is safe to call on
    CPU-only hosts.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def accelerator():
    """
    Pick the best available torch device string.

    Returns:
        str: "cuda" when an NVIDIA GPU is usable, "mps" on Apple
        silicon, otherwise "cpu".
    """
    # Probe backends in preference order; fall back to CPU.
    for device_name, is_available in (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    ):
        if is_available():
            return device_name
    return "cpu"
class ImageAugmentation:
    """
    Centers an image on a white canvas and produces segmentation masks.

    Attributes:
        target_width (int): Desired width of the extended image.
        target_height (int): Desired height of the extended image.
        roi_scale (float): Fraction of the aspect-fitted size that the
            original image occupies inside the canvas (the ROI).
    """
    def __init__(self, target_width, target_height, roi_scale=0.6):
        self.target_width = target_width
        self.target_height = target_height
        self.roi_scale = roi_scale
    def extend_image(self, image: Image.Image) -> Image.Image:
        """
        Center ``image`` on a white target-sized canvas, preserving its
        aspect ratio and shrinking it by ``roi_scale``.

        Args:
            image (PIL.Image.Image): Source image.

        Returns:
            PIL.Image.Image: RGB canvas of (target_width, target_height)
            with the resized image pasted in the middle.
        """
        original_width, original_height = image.size
        # Aspect-fit scale, then shrink further so a margin remains.
        scale = min(self.target_width / original_width, self.target_height / original_height)
        new_width = int(original_width * scale * self.roi_scale)
        new_height = int(original_height * scale * self.roi_scale)
        resized_image = image.resize((new_width, new_height))
        extended_image = Image.new("RGB", (self.target_width, self.target_height), "white")
        paste_x = (self.target_width - new_width) // 2
        paste_y = (self.target_height - new_height) // 2
        extended_image.paste(resized_image, (paste_x, paste_y))
        return extended_image
    def generate_mask_from_bbox(self, image: Image.Image, segmentation_model: str, detection_model) -> Image.Image:
        """
        Detect the most prominent object with YOLO and segment it with SAM.

        Args:
            image (PIL.Image.Image): Input image.
            segmentation_model (str): HF hub name/path of the SAM model.
            detection_model: YOLO weights name/path.

        Returns:
            PIL.Image.Image: Binary mask for the first detected box.

        Raises:
            ValueError: If YOLO detects no objects in the image.
        """
        yolo = YOLO(detection_model)
        processor = SamProcessor.from_pretrained(segmentation_model)
        device = accelerator()
        model = SamModel.from_pretrained(segmentation_model).to(device=device)
        results = yolo(image)
        bboxes = results[0].boxes.xyxy.tolist()
        if not bboxes:
            # Indexing bboxes[0] on an empty list would raise an opaque
            # IndexError; fail with an actionable message instead.
            raise ValueError("No objects detected in the image; cannot generate a mask.")
        input_boxes = [[[bboxes[0]]]]
        # Keep the inputs on the same device as the model (was hard-coded
        # to "cuda", which broke CPU/MPS hosts).
        inputs = processor(load_image(image), input_boxes=input_boxes, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        mask = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0][0].numpy()
        mask_image = Image.fromarray(mask)
        return mask_image
    def invert_mask(self, mask_image: Image.Image) -> Image.Image:
        """
        Return the photographic negative of ``mask_image``.

        Args:
            mask_image (PIL.Image.Image): Mask to invert; converted to
                grayscale ("L") before inversion.

        Returns:
            PIL.Image.Image: Inverted grayscale mask.
        """
        inverted_mask_pil = ImageOps.invert(mask_image.convert("L"))
        return inverted_mask_pil
def pil_to_b64_json(image):
    """
    Serialize a PIL image to a base64-encoded PNG paired with a fresh UUID.

    Args:
        image (PIL.Image.Image): The image to serialize.

    Returns:
        dict: ``{"image_id": <uuid4 string>, "b64_image": <base64 PNG>}``.
    """
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return {"image_id": str(uuid.uuid4()), "b64_image": encoded}
def pil_to_s3_json(image: Image.Image, file_name) -> dict:
    """
    Upload a PIL image to Amazon S3 as PNG and return its ID and signed URL.

    Args:
        image (PIL.Image.Image): The image to upload.
        file_name (str): Base name used to derive the unique S3 key.

    Returns:
        dict: ``{"image_id": <uuid4 string>, "url": <signed URL>}``.
    """
    uploader = S3ManagerService()
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)
    unique_key = uploader.generate_unique_file_name(file_name)
    uploader.upload_file(buffer, unique_key)
    # 43200 seconds == 12 hours before the signed URL expires.
    signed_url = uploader.generate_signed_url(unique_key, exp=43200)
    return {"image_id": str(uuid.uuid4()), "url": signed_url}
if __name__ == "__main__":
    # Smoke test: center a sample image on a 1024x1024 canvas, segment
    # the main object, and save both the mask and its inverse.
    augmenter = ImageAugmentation(target_width=1024, target_height=1024, roi_scale=0.5)
    source_image = Image.open("../sample_data/example3.jpg")
    centered = augmenter.extend_image(source_image)
    mask = augmenter.generate_mask_from_bbox(centered, SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME)
    mask.save("mask.jpg")
    augmenter.invert_mask(mask).save("inverted_mask.jpg")
|