Spaces:
Running
Running
File size: 10,805 Bytes
3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 87c57a3 5f57808 3a6f1f2 87c57a3 3a6f1f2 5f57808 3faa99b 3a6f1f2 c8f8b0e 87c57a3 3a6f1f2 c8f8b0e 87c57a3 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 87c57a3 c8f8b0e 87c57a3 3faa99b c8f8b0e 3faa99b 5f57808 c8f8b0e 5f57808 c8f8b0e 5f57808 3a6f1f2 87c57a3 3faa99b c8f8b0e 3faa99b 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 5f57808 3a6f1f2 3faa99b 3a6f1f2 3faa99b 3a6f1f2 87c57a3 3a6f1f2 c8f8b0e 3a6f1f2 c8f8b0e 3a6f1f2 3faa99b 3a6f1f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 |
import io
from enum import Enum
from typing import Any, List, Optional, Tuple, Union, cast
import numpy as np
import onnxruntime as ort
from cv2 import (
BORDER_DEFAULT,
MORPH_ELLIPSE,
MORPH_OPEN,
GaussianBlur,
getStructuringElement,
morphologyEx,
)
from PIL import Image, ImageOps
from PIL.Image import Image as PILImage
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from pymatting.util.util import stack_images
from scipy.ndimage import binary_erosion
from .session_factory import new_session
from .sessions import sessions_class
from .sessions.base import BaseSession
# Silence ONNX Runtime below severity 3 (ERROR); lower levels are verbose,
# info and warning messages.
_ORT_MIN_LOG_SEVERITY = 3
ort.set_default_logger_severity(_ORT_MIN_LOG_SEVERITY)

# Small elliptical structuring element shared by ``post_process`` for the
# morphological opening of predicted masks.
_KERNEL_SHAPE = (3, 3)
kernel = getStructuringElement(MORPH_ELLIPSE, _KERNEL_SHAPE)
class ReturnType(Enum):
    """
    Output format produced by ``remove``.

    The member is chosen from the type of the input data, so callers get
    back the same kind of object they passed in.
    """

    BYTES = 0  # PNG-encoded ``bytes``
    PILLOW = 1  # a ``PIL.Image.Image``
    NDARRAY = 2  # a ``numpy.ndarray``
def alpha_matting_cutout(
    img: PILImage,
    mask: PILImage,
    foreground_threshold: int,
    background_threshold: int,
    erode_structure_size: int,
) -> PILImage:
    """
    Perform alpha matting on an image using a given mask and threshold values.

    A trimap is built from ``mask``:

    * values strictly above ``foreground_threshold`` are definite
      foreground (255);
    * values strictly below ``background_threshold`` are definite
      background (0);
    * everything else is unknown (128).

    Both definite regions are eroded first, which widens the unknown band
    around the object boundary. pymatting then estimates the alpha matte
    (closed-form matting) and the foreground colours from the image and
    trimap.

    Args:
        img: The source image. Any mode other than "RGB" is converted to RGB,
            because the matting routines operate on a 3-channel colour image.
        mask: Mask aligned with ``img``; presumably single-channel "L"
            (TODO confirm against the session that produced it).
        foreground_threshold: Mask values strictly above this are foreground.
        background_threshold: Mask values strictly below this are background.
        erode_structure_size: Side length of the square erosion structure.
            When it is not positive, no structure is passed, so scipy's
            default structuring element is used and erosion still happens.

    Returns:
        A PIL image built from the stacked foreground colours and alpha
        matte, scaled to 0-255 and clipped to uint8.
    """
    # The matting routines need 3 colour channels. Grayscale ("L"),
    # palette ("P"), "LA", RGBA and CMYK inputs would otherwise give the
    # wrong array shape.
    if img.mode != "RGB":
        img = img.convert("RGB")

    img_array = np.asarray(img)
    mask_array = np.asarray(mask)

    is_foreground = mask_array > foreground_threshold
    is_background = mask_array < background_threshold

    structure = None
    if erode_structure_size > 0:
        structure = np.ones(
            (erode_structure_size, erode_structure_size), dtype=np.uint8
        )

    is_foreground = binary_erosion(is_foreground, structure=structure)
    # border_value=1 treats outside-the-image as background, so background
    # regions touching the image edge are not shrunk away from that edge.
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
    trimap[is_foreground] = 255
    # Assigned last, so background wins where the two regions overlap.
    trimap[is_background] = 0

    # pymatting works on images and trimaps scaled to [0, 1].
    img_normalized = img_array / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    cutout = Image.fromarray(cutout)

    return cutout
def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
    """
    Cut ``img`` out using ``mask`` by compositing it over a transparent canvas.

    Pixels where ``mask`` is fully on come from ``img``. Pixels where it is
    off keep the canvas value (transparent black, RGBA ``0``). Intermediate
    mask values blend the two.

    Args:
        img (PILImage): The source image.
        mask (PILImage): The mask selecting which pixels of ``img`` to keep.

    Returns:
        PILImage: A new image the size of ``img`` holding the cutout.
    """
    transparent_canvas = Image.new("RGBA", img.size, 0)
    return Image.composite(img, transparent_canvas, mask)
def putalpha_cutout(img: PILImage, mask: PILImage) -> PILImage:
    """
    Return a copy of ``img`` whose alpha channel is replaced by ``mask``.

    ``PIL.Image.Image.putalpha`` modifies the image in place, and may change
    its mode to add an alpha band. The mask is therefore applied to a copy,
    so the input image is left untouched. Otherwise, several cutouts taken
    from the same source image (one per predicted mask) would all be the
    same object, and one mask's alpha would leak into later cutouts.

    Args:
        img (PILImage): The source image; not modified.
        mask (PILImage): The mask to use as the alpha channel.

    Returns:
        PILImage: A new image equal to ``img`` with ``mask`` as its alpha.
    """
    cutout = img.copy()
    cutout.putalpha(mask)
    return cutout
def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
    """
    Concatenate multiple images vertically, top to bottom in list order.

    The input list is not modified. A single-element list returns that
    element unchanged.

    Args:
        imgs (List[PILImage]): The list of images to be concatenated.

    Returns:
        PILImage: The concatenated image.

    Raises:
        IndexError: If ``imgs`` is empty.
    """
    # Index and slice instead of ``imgs.pop(0)``, which mutated the caller's
    # list and is O(n) per call.
    pivot = imgs[0]
    for im in imgs[1:]:
        pivot = get_concat_v(pivot, im)
    return pivot
def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
    """
    Concatenate two images vertically.

    The result is an RGBA canvas as wide as the wider input and as tall as
    both inputs together. ``img1`` is pasted at the top-left and ``img2``
    directly below it. Any area not covered by an image stays transparent.

    Args:
        img1 (PILImage): The first image.
        img2 (PILImage): The second image to be concatenated below the first image.

    Returns:
        PILImage: The concatenated image.
    """
    # Use the wider of the two so a wider ``img2`` is not silently cropped.
    width = max(img1.width, img2.width)
    dst = Image.new("RGBA", (width, img1.height + img2.height))
    dst.paste(img1, (0, 0))
    dst.paste(img2, (0, img1.height))
    return dst
def post_process(mask: np.ndarray) -> np.ndarray:
    """
    Smooth a mask's boundary with morphological operations and re-binarize it.

    Steps:

    1. A morphological opening with the module-level ``kernel`` drops small
       isolated specks.
    2. A 5x5 Gaussian blur (sigma 2) softens jagged edges.
    3. Thresholding at 127 maps every pixel back to 0 or 255.

    Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757

    args:
        mask: Binary Numpy Mask (presumably uint8 0/255 as produced from a
            PIL "L" mask; TODO confirm for other callers)

    returns:
        A uint8 array of the same shape containing only 0 and 255.
    """
    opened = morphologyEx(mask, MORPH_OPEN, kernel)
    blurred = GaussianBlur(
        opened, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT
    )
    binarized = np.where(blurred < 127, 0, 255)  # type: ignore
    return binarized.astype(np.uint8)
def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
    """
    Apply the specified background color to the image.

    A solid ``color`` canvas the size of ``img`` is created. ``img`` is then
    pasted on top of it, using its own alpha channel as the paste mask.

    Args:
        img (PILImage): The image to be modified. Images without an RGBA
            alpha band (e.g. RGB, L, P) are converted to RGBA first. Such an
            image therefore fully covers the background.
        color (Tuple[int, int, int, int]): The RGBA color to be applied.

    Returns:
        PILImage: A new RGBA image with the background color applied.
    """
    r, g, b, a = color
    # ``paste(..., mask=img)`` only accepts masks with a transparency band
    # ("1", "L", "LA", "RGBA", ...). An RGB or P image would be rejected as
    # a "bad transparency mask", and an L image would be misused as its own
    # mask.
    if img.mode != "RGBA":
        img = img.convert("RGBA")
    colored_image = Image.new("RGBA", img.size, (r, g, b, a))
    colored_image.paste(img, mask=img)
    return colored_image
def fix_image_orientation(img: PILImage) -> PILImage:
    """
    Rotate/flip ``img`` according to its EXIF orientation data.

    Args:
        img (PILImage): The image whose orientation should be normalized.

    Returns:
        PILImage: The image returned by ``ImageOps.exif_transpose``.
    """
    transposed = ImageOps.exif_transpose(img)
    # ``cast`` only narrows the declared type for type checkers; it does
    # nothing at runtime.
    return cast(PILImage, transposed)
def download_models() -> None:
    """
    Download the model files for every session class in ``sessions_class``.

    Each class's ``download_models`` is invoked once, in iteration order.
    """
    for session_cls in sessions_class:
        session_cls.download_models()
def _load_input_image(data: Any, force_return_bytes: bool) -> Tuple[PILImage, ReturnType]:
    """Decode ``data`` into a PIL image and choose the matching output type."""
    if isinstance(data, bytes):
        img = cast(PILImage, Image.open(io.BytesIO(data)))
        return_type = ReturnType.BYTES
    elif isinstance(data, PILImage):
        img = cast(PILImage, data)
        return_type = ReturnType.PILLOW
    elif isinstance(data, np.ndarray):
        img = cast(PILImage, Image.fromarray(data))
        return_type = ReturnType.NDARRAY
    elif force_return_bytes:
        # Other bytes-like objects (bytearray, memoryview, ...) are decoded
        # only when the caller explicitly asks for bytes output.
        img = cast(PILImage, Image.open(io.BytesIO(cast(bytes, data))))
        return_type = ReturnType.BYTES
    else:
        raise ValueError(
            "Input type {} is not supported. Try using force_return_bytes=True to force python bytes output".format(
                type(data)
            )
        )

    # The flag controls the OUTPUT format only. It must not change how a
    # PIL image or ndarray input is decoded.
    if force_return_bytes:
        return_type = ReturnType.BYTES
    return img, return_type


def _plain_cutout(img: PILImage, mask: PILImage, putalpha: bool) -> PILImage:
    """Cut ``img`` out with ``mask`` without alpha matting."""
    if putalpha:
        return putalpha_cutout(img, mask)
    return naive_cutout(img, mask)


def remove(
    data: Union[bytes, PILImage, np.ndarray],
    alpha_matting: bool = False,
    alpha_matting_foreground_threshold: int = 240,
    alpha_matting_background_threshold: int = 10,
    alpha_matting_erode_size: int = 10,
    session: Optional[BaseSession] = None,
    only_mask: bool = False,
    post_process_mask: bool = False,
    bgcolor: Optional[Tuple[int, int, int, int]] = None,
    force_return_bytes: bool = False,
    *args: Optional[Any],
    **kwargs: Optional[Any]
) -> Union[bytes, PILImage, np.ndarray]:
    """
    Remove the background from an input image.

    Processing steps:

    1. ``data`` is decoded into a PIL image and its EXIF orientation is
       fixed.
    2. ``session`` predicts one or more masks; a new "u2net" session is
       created when none is given.
    3. Each mask is optionally post-processed and turned into a cutout.
       The cutout is the mask itself when ``only_mask`` is set, an
       alpha-matted cutout when ``alpha_matting`` is set, or a plain cutout
       otherwise.
    4. Multiple cutouts are stacked vertically. If no masks are predicted,
       the (re-oriented) input image itself is used.
    5. If ``bgcolor`` is given and ``only_mask`` is false, the colour is
       applied behind the result.

    The output has the same kind as the input: bytes in gives PNG bytes out,
    PIL in gives PIL out, ndarray in gives ndarray out. When
    ``force_return_bytes`` is true, PNG bytes are always returned.

    Parameters:
        data (Union[bytes, PILImage, np.ndarray]): The input image data. Other
            bytes-like objects are accepted only with ``force_return_bytes``.
        alpha_matting (bool, optional): Flag indicating whether to use alpha
            matting. If matting raises ``ValueError``, a plain cutout is used
            instead. Defaults to False.
        alpha_matting_foreground_threshold (int, optional): Foreground
            threshold for alpha matting. Defaults to 240.
        alpha_matting_background_threshold (int, optional): Background
            threshold for alpha matting. Defaults to 10.
        alpha_matting_erode_size (int, optional): Erosion size for alpha
            matting. Defaults to 10.
        session (Optional[BaseSession], optional): The session used to
            predict masks. Defaults to None (a new "u2net" session).
        only_mask (bool, optional): Flag indicating whether to return only
            the masks. Defaults to False.
        post_process_mask (bool, optional): Flag indicating whether to
            post-process the masks. Defaults to False.
        bgcolor (Optional[Tuple[int, int, int, int]], optional): RGBA
            background color for the cutout image. Defaults to None.
        force_return_bytes (bool, optional): Always return PNG-encoded bytes,
            whatever the input type. Defaults to False.
        *args (Optional[Any]): Forwarded to ``new_session`` (when a session
            is created here) and to ``session.predict``.
        **kwargs (Optional[Any]): Forwarded likewise, except ``putalpha``
            (bool, default False), which is consumed here. When it is true,
            plain cutouts replace the image's alpha channel with the mask
            instead of compositing over a transparent canvas.

    Returns:
        Union[bytes, PILImage, np.ndarray]: The cutout image with the
        background removed.

    Raises:
        ValueError: If ``data`` is not bytes, a PIL image or a numpy array and
            ``force_return_bytes`` is false.
    """
    img, return_type = _load_input_image(data, force_return_bytes)

    # Consumed here so it is not forwarded to the session.
    putalpha = kwargs.pop("putalpha", False)

    # Fix image orientation
    img = fix_image_orientation(img)

    if session is None:
        session = new_session("u2net", *args, **kwargs)

    masks = session.predict(img, *args, **kwargs)
    cutouts = []

    for mask in masks:
        if post_process_mask:
            mask = Image.fromarray(post_process(np.array(mask)))

        if only_mask:
            cutout = mask
        elif alpha_matting:
            try:
                cutout = alpha_matting_cutout(
                    img,
                    mask,
                    alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold,
                    alpha_matting_erode_size,
                )
            except ValueError:
                # Best-effort: fall back to a plain cutout when matting fails.
                cutout = _plain_cutout(img, mask, putalpha)
        else:
            cutout = _plain_cutout(img, mask, putalpha)

        cutouts.append(cutout)

    cutout = get_concat_v_multi(cutouts) if cutouts else img

    if bgcolor is not None and not only_mask:
        cutout = apply_background_color(cutout, bgcolor)

    if return_type == ReturnType.PILLOW:
        return cutout

    if return_type == ReturnType.NDARRAY:
        return np.asarray(cutout)

    bio = io.BytesIO()
    cutout.save(bio, "PNG")
    return bio.getvalue()
|