# picpilot-server/scripts/api_utils.py
# VikramSingh178's picture
# chore: Update inpainting pipeline configuration and parameters
# 7ab9afa
import torch
from ultralytics import YOLO
from transformers import SamModel, SamProcessor
import numpy as np
from PIL import Image, ImageOps
from config import SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME
from diffusers.utils import load_image
import gc
from s3_manager import S3ManagerService
import io
from io import BytesIO
import base64
import uuid
def clear_memory():
    """
    Frees memory by running the garbage collector and then emptying the CUDA cache.

    Intended to be called after memory-intensive PyTorch work.
    """
    # Garbage collection runs first so that memory held by unreachable
    # objects can be handed back when the CUDA cache is emptied.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def accelerator():
    """
    Picks the torch device backend to run on.

    CUDA is checked first, then MPS; CPU is the fallback when neither is
    available.

    Returns:
        str: The name of the device accelerator ('cuda', 'mps', or 'cpu').
    """
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
class ImageAugmentation:
    """
    Centers an image on a white canvas and derives object masks for it.

    Attributes:
        target_width (int): Width of the extended (output) image in pixels.
        target_height (int): Height of the extended (output) image in pixels.
        roi_scale (float): Extra scale factor applied after fitting the image to
            the target size; it determines how much of the canvas the region of
            interest (the original image) occupies.
    """

    def __init__(self, target_width, target_height, roi_scale=0.6):
        self.target_width = target_width
        self.target_height = target_height
        self.roi_scale = roi_scale

    def extend_image(self, image: Image.Image) -> Image.Image:
        """
        Places a scaled copy of the image at the center of a white canvas.

        The image is scaled to fit inside the target dimensions while keeping
        its aspect ratio, shrunk further by ``roi_scale``, and pasted centered
        on a white RGB canvas of size ``(target_width, target_height)``.

        Args:
            image (PIL.Image.Image): The image to extend.
        Returns:
            PIL.Image.Image: The white canvas with the resized image pasted on it.
        """
        original_width, original_height = image.size
        scale = min(self.target_width / original_width, self.target_height / original_height)
        # int() truncates, so a very thin image or a tiny roi_scale could give a
        # 0-pixel dimension, which is not a valid resize target; clamp to 1 px.
        new_width = max(1, int(original_width * scale * self.roi_scale))
        new_height = max(1, int(original_height * scale * self.roi_scale))
        resized_image = image.resize((new_width, new_height))
        extended_image = Image.new("RGB", (self.target_width, self.target_height), "white")
        paste_x = (self.target_width - new_width) // 2
        paste_y = (self.target_height - new_height) // 2
        extended_image.paste(resized_image, (paste_x, paste_y))
        return extended_image

    def generate_mask_from_bbox(self, image: Image.Image, segmentation_model: str, detection_model) -> Image.Image:
        """
        Generates a mask for the first detected object using YOLO and SAM models.

        YOLO produces bounding boxes for the image; the first box is used as the
        prompt for the SAM model, and the first mask returned by SAM's
        post-processing is converted to a PIL image.

        Args:
            image (PIL.Image.Image): The input image.
            segmentation_model (str): Identifier passed to
                ``SamProcessor.from_pretrained`` / ``SamModel.from_pretrained``.
            detection_model: Model identifier passed to ``YOLO``.
        Returns:
            PIL.Image.Image: The generated mask.
        Raises:
            IndexError: If the detector returns no bounding boxes.
        """
        # Resolve the device once so the model and its inputs always live on the
        # same device (inputs were previously forced onto "cuda").
        device = accelerator()
        yolo = YOLO(detection_model)
        processor = SamProcessor.from_pretrained(segmentation_model)
        model = SamModel.from_pretrained(segmentation_model).to(device=device)
        results = yolo(image)
        bboxes = results[0].boxes.xyxy.tolist()
        if not bboxes:
            # Same exception type as the bare bboxes[0] lookup, but with a
            # message that explains what actually went wrong.
            raise IndexError("No objects were detected in the image; cannot generate a mask.")
        input_boxes = [[[bboxes[0]]]]
        inputs = processor(load_image(image), input_boxes=input_boxes, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Post-processing runs on CPU tensors; [0][0][0] selects the first image,
        # first box, and first mask of the post-processed output.
        mask = processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )[0][0][0].numpy()
        mask_image = Image.fromarray(mask)
        return mask_image

    def invert_mask(self, mask_image: Image.Image) -> Image.Image:
        """
        Inverts the given mask image.

        Args:
            mask_image (PIL.Image.Image): The mask to invert; it is converted to
                mode "L" before inversion.
        Returns:
            PIL.Image.Image: The inverted mask in mode "L".
        """
        inverted_mask_pil = ImageOps.invert(mask_image.convert("L"))
        return inverted_mask_pil
def pil_to_b64_json(image):
    """
    Serializes a PIL image as PNG and wraps it in a base64 JSON payload.

    Args:
        image (PIL.Image.Image): The PIL image object to be converted.
    Returns:
        dict: ``{"image_id": <random UUID4 string>, "b64_image": <base64 PNG>}``.
    """
    with BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        encoded_png = base64.b64encode(png_stream.getvalue())
    return {
        "image_id": str(uuid.uuid4()),
        "b64_image": encoded_png.decode("utf-8"),
    }
def pil_to_s3_json(image: Image.Image, file_name) -> dict:
    """
    Stores a PIL image in S3 as PNG and returns its ID together with a signed URL.

    Args:
        image (PIL.Image.Image): The PIL image to be uploaded.
        file_name (str): Base name from which the unique S3 file name is generated.
    Returns:
        dict: ``{"image_id": <random UUID4 string>, "url": <signed URL>}``.
    """
    storage = S3ManagerService()

    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    # Rewind so the upload reads the PNG from its first byte.
    png_buffer.seek(0)

    object_name = storage.generate_unique_file_name(file_name)
    storage.upload_file(png_buffer, object_name)

    # 43200 seconds = 12 hours of validity for the signed URL.
    url = storage.generate_signed_url(object_name, exp=43200)
    return {"image_id": str(uuid.uuid4()), "url": url}
if __name__ == "__main__":
    # Manual smoke run: extend a sample image, build its object mask, and
    # write both the mask and its inverse to the working directory.
    demo = ImageAugmentation(target_width=1024, target_height=1024, roi_scale=0.5)
    sample = Image.open("../sample_data/example3.jpg")
    canvas = demo.extend_image(sample)
    object_mask = demo.generate_mask_from_bbox(canvas, SEGMENTATION_MODEL_NAME, DETECTION_MODEL_NAME)
    background_mask = demo.invert_mask(object_mask)
    for picture, destination in ((object_mask, "mask.jpg"), (background_mask, "inverted_mask.jpg")):
        picture.save(destination)