import io
from enum import Enum
from typing import Any, List, Optional, Tuple, Union, cast

import numpy as np
import onnxruntime as ort
from cv2 import (
    BORDER_DEFAULT,
    MORPH_ELLIPSE,
    MORPH_OPEN,
    GaussianBlur,
    getStructuringElement,
    morphologyEx,
)
from PIL import Image, ImageOps
from PIL.Image import Image as PILImage
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from pymatting.util.util import stack_images
from scipy.ndimage import binary_erosion

from .session_factory import new_session
from .sessions import sessions_class
from .sessions.base import BaseSession

# Severity 3 is passed to onnxruntime's default logger at import time.
ort.set_default_logger_severity(3)

# 3x3 elliptical structuring element used by `post_process`.
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))


class ReturnType(Enum):
    """Output container selected by `remove` from the type of its input."""

    BYTES = 0
    PILLOW = 1
    NDARRAY = 2


def alpha_matting_cutout(
    img: PILImage,
    mask: PILImage,
    foreground_threshold: int,
    background_threshold: int,
    erode_structure_size: int,
) -> PILImage:
    """
    Perform alpha matting on an image using a given mask and threshold values.

    A trimap is built from `mask`: pixels above `foreground_threshold` are
    marked foreground (255), pixels below `background_threshold` are marked
    background (0), and everything else is left unknown (128). Both regions
    are eroded before being written into the trimap. The trimap is then handed
    to pymatting to estimate the alpha channel and the foreground colours.

    Args:
        img (PILImage): The source image. "RGBA" and "CMYK" images are
            converted to "RGB" first.
        mask (PILImage): The mask that is thresholded into the trimap.
        foreground_threshold (int): Mask values above this are foreground.
        background_threshold (int): Mask values below this are background.
        erode_structure_size (int): Side length of the square erosion
            structure. When not positive, scipy's default structure is used.

    Returns:
        PILImage: The cutout of the foreground object from the original image.
    """
    if img.mode == "RGBA" or img.mode == "CMYK":
        img = img.convert("RGB")

    img_array = np.asarray(img)
    mask_array = np.asarray(mask)

    is_foreground = mask_array > foreground_threshold
    is_background = mask_array < background_threshold

    structure = None
    if erode_structure_size > 0:
        structure = np.ones(
            (erode_structure_size, erode_structure_size), dtype=np.uint8
        )

    is_foreground = binary_erosion(is_foreground, structure=structure)
    # border_value=1 treats everything outside the image as background, so the
    # background region is not eroded away along the image border.
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
    trimap[is_foreground] = 255
    trimap[is_background] = 0

    img_normalized = img_array / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    cutout = Image.fromarray(cutout)

    return cutout


def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
    """
    Perform a simple cutout operation on an image using a mask.

    The image is composited over a fully transparent "RGBA" canvas of the same
    size, using `mask` as the blend mask.

    Args:
        img (PILImage): The image to cut out from.
        mask (PILImage): The mask selecting which pixels of `img` are kept.

    Returns:
        PILImage: The cutout of the original image using the mask.
    """
    empty = Image.new("RGBA", (img.size), 0)
    cutout = Image.composite(img, empty, mask)
    return cutout


def putalpha_cutout(img: PILImage, mask: PILImage) -> PILImage:
    """
    Apply the specified mask to the image as an alpha cutout.

    Note that `img` is modified in place; pass a copy if the original must be
    preserved.

    Args:
        img (PILImage): The image to be modified.
        mask (PILImage): The mask to be applied.

    Returns:
        PILImage: The same image object with the alpha cutout applied.
    """
    img.putalpha(mask)
    return img


def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
    """
    Concatenate multiple images vertically.

    The input list is left untouched.

    Args:
        imgs (List[PILImage]): The list of images to be concatenated.

    Returns:
        PILImage: The concatenated image. A single-element list yields that
        element unchanged.

    Raises:
        IndexError: If `imgs` is empty.
    """
    # Read by index/slice rather than pop(0) so the caller's list is not mutated.
    pivot = imgs[0]
    for im in imgs[1:]:
        pivot = get_concat_v(pivot, im)
    return pivot


def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
    """
    Concatenate two images vertically.

    The result is an "RGBA" canvas as wide as `img1` and as tall as both
    images together.

    Args:
        img1 (PILImage): The first image.
        img2 (PILImage): The second image to be concatenated below the first image.

    Returns:
        PILImage: The concatenated image.
    """
    dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
    dst.paste(img1, (0, 0))
    dst.paste(img2, (0, img1.height))
    return dst


def post_process(mask: np.ndarray) -> np.ndarray:
    """
    Post Process the mask for a smooth boundary by applying Morphological Operations
    Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757

    The mask is opened with the module-level 3x3 elliptical kernel, blurred
    with a 5x5 Gaussian (sigma 2), and re-binarized at 127.

    args:
        mask: Binary Numpy Mask

    Returns:
        np.ndarray: A uint8 mask containing only the values 0 and 255.
    """
    mask = morphologyEx(mask, MORPH_OPEN, kernel)
    mask = GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT)
    mask = np.where(mask < 127, 0, 255).astype(np.uint8)  # type: ignore
    return mask


def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
    """
    Apply the specified background color to the image.

    Args:
        img (PILImage): The image to be modified. It is also used as its own
            paste mask.
        color (Tuple[int, int, int, int]): The RGBA color to be applied.

    Returns:
        PILImage: A new "RGBA" image with `img` pasted over the solid color.
    """
    r, g, b, a = color
    colored_image = Image.new("RGBA", img.size, (r, g, b, a))
    colored_image.paste(img, mask=img)

    return colored_image


def fix_image_orientation(img: PILImage) -> PILImage:
    """
    Fix the orientation of the image based on its EXIF data.

    Args:
        img (PILImage): The image to be fixed.

    Returns:
        PILImage: The fixed image.
    """
    return cast(PILImage, ImageOps.exif_transpose(img))


def download_models() -> None:
    """
    Download models for image processing.

    Calls `download_models()` on every entry of `sessions_class`.
    """
    for session in sessions_class:
        session.download_models()


def _mask_cutout(img: PILImage, mask: PILImage, putalpha: bool) -> PILImage:
    """Cut `img` out with `mask` without modifying `img` itself."""
    if putalpha:
        # putalpha_cutout works in place; give it a copy so that each mask
        # yields its own cutout instead of all of them aliasing one image.
        return putalpha_cutout(img.copy(), mask)
    return naive_cutout(img, mask)


def remove(
    data: Union[bytes, PILImage, np.ndarray],
    alpha_matting: bool = False,
    alpha_matting_foreground_threshold: int = 240,
    alpha_matting_background_threshold: int = 10,
    alpha_matting_erode_size: int = 10,
    session: Optional[BaseSession] = None,
    only_mask: bool = False,
    post_process_mask: bool = False,
    bgcolor: Optional[Tuple[int, int, int, int]] = None,
    force_return_bytes: bool = False,
    *args: Optional[Any],
    **kwargs: Optional[Any]
) -> Union[bytes, PILImage, np.ndarray]:
    """
    Remove the background from an input image.

    The input is converted to a PIL image if necessary and its orientation is
    fixed from EXIF data. The session (a new 'u2net' session when none is
    given) predicts one or more masks; each mask is optionally post-processed
    and turned into a cutout. The cutouts are stacked vertically into a single
    image and, if requested, placed over a background color.

    The result has the same kind as the input (bytes in, PNG bytes out; PIL
    image in, PIL image out; numpy array in, numpy array out) unless
    `force_return_bytes` is true, in which case PNG bytes are returned.

    Parameters:
        data (Union[bytes, PILImage, np.ndarray]): The input image data.
        alpha_matting (bool, optional): Flag indicating whether to use alpha matting.
            If alpha matting raises ValueError, the plain cutout is used
            instead. Defaults to False.
        alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.
        alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.
        alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.
        session (Optional[BaseSession], optional): The session used to predict the masks.
            Defaults to None, which creates a new 'u2net' session.
        only_mask (bool, optional): Flag indicating whether to return only the binary masks. Defaults to False.
        post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.
        bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image.
            Ignored when `only_mask` is true. Defaults to None.
        force_return_bytes (bool, optional): Flag indicating whether to return the cutout image as bytes.
            Also lets non-`bytes` bytes-like input be read as an encoded image. Defaults to False.
        *args (Optional[Any]): Additional positional arguments, forwarded to the session.
        **kwargs (Optional[Any]): Additional keyword arguments, forwarded to the session.
            The `putalpha` key (default False) is consumed here: when true,
            the mask is written into the alpha channel instead of compositing.

    Returns:
        Union[bytes, PILImage, np.ndarray]: The cutout image with the background removed.

    Raises:
        ValueError: If the type of `data` is not supported.
    """
    if isinstance(data, PILImage):
        # Checked first: a PIL image cannot be wrapped in BytesIO, so with
        # force_return_bytes it must still be taken as-is and only the output
        # format changes.
        return_type = ReturnType.BYTES if force_return_bytes else ReturnType.PILLOW
        img = cast(PILImage, data)
    elif isinstance(data, bytes) or force_return_bytes:
        return_type = ReturnType.BYTES
        img = cast(PILImage, Image.open(io.BytesIO(cast(bytes, data))))
    elif isinstance(data, np.ndarray):
        return_type = ReturnType.NDARRAY
        img = cast(PILImage, Image.fromarray(data))
    else:
        raise ValueError(
            (
                "Input type {} is not supported. "
                "Try using force_return_bytes=True to force python bytes output"
            ).format(type(data))
        )

    # Consumed here so it is not forwarded to the session.
    putalpha = kwargs.pop("putalpha", False)

    # Fix image orientation
    img = fix_image_orientation(img)

    if session is None:
        session = new_session("u2net", *args, **kwargs)

    masks = session.predict(img, *args, **kwargs)
    cutouts = []

    for mask in masks:
        if post_process_mask:
            mask = Image.fromarray(post_process(np.array(mask)))

        if only_mask:
            cutout = mask
        elif alpha_matting:
            try:
                cutout = alpha_matting_cutout(
                    img,
                    mask,
                    alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold,
                    alpha_matting_erode_size,
                )
            except ValueError:
                # Best-effort fallback to the plain cutout.
                cutout = _mask_cutout(img, mask, putalpha)
        else:
            cutout = _mask_cutout(img, mask, putalpha)

        cutouts.append(cutout)

    # With no masks at all, the (orientation-fixed) input is returned as-is.
    cutout = img
    if cutouts:
        cutout = get_concat_v_multi(cutouts)

    if bgcolor is not None and not only_mask:
        cutout = apply_background_color(cutout, bgcolor)

    if ReturnType.PILLOW == return_type:
        return cutout

    if ReturnType.NDARRAY == return_type:
        return np.asarray(cutout)

    bio = io.BytesIO()
    cutout.save(bio, "PNG")
    return bio.getvalue()