# prodia2 / cutter.py
# Author: Dagfinn1962 — duplicated from pikto/prodia (commit e774b98, 3.17 kB).
import PIL
import numpy as np
from PIL import Image, ImageColor, ImageDraw
from PIL.Image import Image as PILImage
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from pymatting.util.util import stack_images
from rembg.bg import post_process, naive_cutout, apply_background_color
from scipy.ndimage import binary_erosion
def alpha_matting_cutout(img: PILImage, trimap: np.ndarray) -> PILImage:
    """Cut out the foreground of ``img`` using closed-form alpha matting.

    Args:
        img: Source image. Any mode other than ``"RGB"`` is converted to RGB
            first so the matting routines always receive an H x W x 3 array.
        trimap: Trimap on a 0-255 scale (it is divided by 255 below), e.g. the
            uint8 array returned by ``generate_trimap`` (255 = foreground,
            0 = background, 128 = unknown).

    Returns:
        A four-channel image built from the estimated foreground colours
        stacked with the estimated alpha matte.

    Raises:
        ValueError: presumably propagated from pymatting when it rejects its
            inputs (e.g. a degenerate trimap) — TODO confirm against pymatting.
    """
    if img.mode != "RGB":
        # np.asarray of an "L"/"P" image is 2-D (and "P" holds palette indices,
        # not colours), and "LA"/"RGBA"/"CMYK" have the wrong channel count, so
        # normalise every non-RGB mode rather than only RGBA/CMYK.
        img = img.convert("RGB")

    # pymatting works on floats in [0, 1].
    pixels = np.asarray(img) / 255.0
    alpha = estimate_alpha_cf(pixels, trimap / 255.0)
    foreground = estimate_foreground_ml(pixels, alpha)

    cutout = stack_images(foreground, alpha)
    # Back to 8-bit; clip guards against small numeric overshoot.
    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(cutout)
def generate_trimap(
    mask: PILImage,
    foreground_threshold: int,
    background_threshold: int,
    erode_structure_size: int,
) -> np.ndarray:
    """Build an alpha-matting trimap from a segmentation mask.

    Mask values strictly above ``foreground_threshold`` are candidate
    foreground; values strictly below ``background_threshold`` are candidate
    background. Both regions are shrunk with a binary erosion. The trimap is
    then 255 on the eroded foreground, 0 on the eroded background and 128
    everywhere else. Where the two regions overlap, background (0) wins.

    Args:
        mask: Segmentation mask; anything ``np.asarray`` accepts.
        foreground_threshold: Values above this count as foreground.
        background_threshold: Values below this count as background.
        erode_structure_size: Side of the square erosion kernel. When it is
            not positive, no kernel is passed and SciPy falls back to its
            default structuring element, so some erosion still happens.

    Returns:
        A ``uint8`` array with the same shape as ``mask``.
    """
    values = np.asarray(mask)

    kernel = (
        np.ones((erode_structure_size, erode_structure_size), dtype=np.uint8)
        if erode_structure_size > 0
        else None
    )

    # Foreground touching the image edge is eroded (border counts as 0);
    # background at the edge is kept (border counts as 1).
    sure_fg = binary_erosion(values > foreground_threshold, structure=kernel)
    sure_bg = binary_erosion(
        values < background_threshold, structure=kernel, border_value=1
    )

    # Apply background last so it overrides foreground on overlap.
    labels = np.where(sure_fg, 255, 128)
    labels = np.where(sure_bg, 0, labels)
    return labels.astype(np.uint8)
def get_background_dominant_color(img: PILImage, mask: PILImage) -> tuple:
    """Estimate a representative colour of the background region of ``img``.

    The mask is inverted so the background becomes opaque, and the result is
    used as the alpha channel of an RGB copy of ``img``. That image is
    resized to a single pixel, and the pixel's colour is returned with full
    opacity.

    Args:
        img: Source image. It is never modified, because a converted RGB copy
            is used. The conversion also guarantees three colour channels for
            modes such as "L" that would otherwise become two-channel "LA".
        mask: Foreground mask, the same size as ``img``. It must be a mode
            ``putalpha`` and ``ImageOps.invert`` accept — presumably "L";
            confirm with the caller.

    Returns:
        An ``(r, g, b, 255)`` tuple.
    """
    # Explicit import: `import PIL` alone does not load the ImageOps
    # submodule, so `PIL.ImageOps` can raise AttributeError.
    from PIL import ImageOps

    # convert() returns a new image, so putalpha below cannot touch `img`.
    background = img.convert("RGB")
    background.putalpha(ImageOps.invert(mask))
    # NOTE(review): relies on Pillow weighting colours by alpha when resizing
    # RGBA, so masked-out (foreground) pixels contribute little — confirm.
    r, g, b, _ = background.resize((1, 1)).getpixel((0, 0))
    return r, g, b, 255
def remove(
    session, img: PILImage, smoot: bool, matting: tuple, color
) -> "tuple[PILImage, PILImage]":
    """Remove the background from ``img``.

    Args:
        session: Segmentation session; ``session.predict(img)[0]`` is used as
            the mask (presumably a rembg session returning PIL masks — confirm).
        img: Image to cut out.
        smoot: When true, the mask is cleaned with rembg's ``post_process``.
        matting: ``(foreground_threshold, background_threshold,
            erode_structure_size)``. If any value is positive, a trimap is
            built and alpha matting is used; otherwise a naive cutout is made.
        color: ``True`` fills the background with a colour estimated from the
            background region. A string (anything ``ImageColor.getcolor``
            understands) fills it with that colour. Any other value leaves the
            cutout as produced.

    Returns:
        ``(cutout, mask)``: the cutout and the mask that was used. When alpha
        matting ran, the mask is the trimap converted to an image.

    Raises:
        ValueError: propagated unchanged from alpha matting, or raised for an
            unparseable colour string.
    """
    mask = session.predict(img)[0]
    if smoot:
        mask = PIL.Image.fromarray(post_process(np.array(mask)))

    fg_t, bg_t, erode = matting
    if fg_t > 0 or bg_t > 0 or erode > 0:
        trimap = generate_trimap(mask, fg_t, bg_t, erode)
        # Any ValueError from matting propagates to the caller as-is; the old
        # `except ValueError as err: raise err` wrapper added nothing.
        cutout = alpha_matting_cutout(img, trimap)
        mask = PIL.Image.fromarray(trimap)
    else:
        cutout = naive_cutout(img, mask)

    # `is True` (not truthiness) so that colour strings take the next branch.
    if color is True:
        fill = get_background_dominant_color(img, mask)
        cutout = apply_background_color(cutout, fill)
    elif isinstance(color, str):
        r, g, b = ImageColor.getcolor(color, "RGB")
        cutout = apply_background_color(cutout, (r, g, b, 255))
    return cutout, mask
def make_label(text, width=600, height=200, color="black") -> PILImage:
    """Render ``text`` centred on a solid-colour RGB canvas.

    Args:
        text: String to draw with Pillow's default font and default fill.
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        color: Background colour; any value ``Image.new`` accepts.

    Returns:
        A new ``width`` x ``height`` RGB image.
    """
    image = Image.new("RGB", (width, height), color)
    draw = ImageDraw.Draw(image)
    if hasattr(draw, "textbbox"):
        # ImageDraw.textsize was removed in Pillow 10. textbbox also reports
        # the glyphs' offset from the drawing origin, used below.
        left, top, right, bottom = draw.textbbox((0, 0), text)
    else:  # old Pillow without textbbox
        right, bottom = draw.textsize(text)
        left = top = 0
    text_width, text_height = right - left, bottom - top
    # Centre on both axes: subtract the text height as well as the width,
    # otherwise the top of the text sits at mid-height.
    origin = ((width - text_width) / 2 - left, (height - text_height) / 2 - top)
    draw.text(origin, text)
    return image