Spaces:
Running
on
L40S
Running
on
L40S
import os, sys | |
sys.path.insert(0, f"{os.path.dirname(os.path.dirname(os.path.abspath(__file__)))}") | |
import numpy as np | |
from PIL import Image | |
from rembg import remove, new_session | |
from infer.utils import timing_decorator | |
class Removebg:
    """Background remover backed by a rembg session.

    Calling an instance produces an RGBA image whose foreground has been
    segmented, whose near-white foreground pixels have been darkened by
    ``white_out_background``, and which has been cropped/padded/resized by
    ``preprocess``.
    """

    def __init__(self, name="u2net"):
        """Create the rembg session.

        Args:
            name: model name passed to ``rembg.new_session``.
        """
        self.session = new_session(name)

    def __call__(self, rgb_maybe, force=True):
        """Remove the background of an image.

        Args:
            rgb_maybe: PIL.Image in RGB or RGBA mode.
            force: if the input is RGBA, convert it to RGB and run background
                removal anyway. If False, the input's existing alpha channel
                is reused as the foreground mask.

        Returns:
            PIL.Image in RGBA mode, as returned by ``preprocess``.
        """
        if rgb_maybe.mode == "RGBA" and not force:
            # Reuse the existing alpha, but on a copy: white_out_background
            # rewrites its argument in place and must not touch the caller's
            # image.
            rgba_img = rgb_maybe.copy()
        else:
            if rgb_maybe.mode == "RGBA":
                # Drop the existing alpha so rembg segments from colour alone.
                rgb_maybe = rgb_maybe.convert("RGB")
            rgba_img = remove(rgb_maybe, session=self.session)
        rgba_img = white_out_background(rgba_img)
        return preprocess(rgba_img)
def white_out_background(pil_img):
    """Normalise background pixels and darken near-white foreground pixels.

    Pixels with alpha < 16 become fully transparent white (255, 255, 255, 0).
    Remaining pixels whose R, G and B are all > 235 get R = G = B = 235 and
    keep their alpha (presumably so the foreground stays distinguishable
    from a white background).

    The image is modified in place via ``putdata`` and also returned.

    Args:
        pil_img: PIL.Image with four channels, normally RGBA.

    Returns:
        The same ``pil_img`` object, with updated pixel data.

    Raises:
        ValueError: if the image does not have exactly four channels.
    """
    pixels = np.array(pil_img)  # (H, W, C) copy of the pixel data
    if pixels.ndim != 3 or pixels.shape[-1] != 4:
        raise ValueError(
            f"white_out_background expects a 4-channel image, got array shape {pixels.shape}"
        )

    background = pixels[..., 3] < 16
    near_white = np.all(pixels[..., :3] > 235, axis=-1)

    pixels[near_white, :3] = 235
    # Applied last so background pixels win regardless of their colour.
    pixels[background] = (255, 255, 255, 0)

    # putdata needs a flat sequence of per-pixel tuples.
    pil_img.putdata(list(map(tuple, pixels.reshape(-1, 4).tolist())))
    return pil_img
def preprocess(rgba_img, size=(512, 512), ratio=1.15):
    """Crop an RGBA image to its foreground, centre it on a square canvas, resize.

    The foreground is every pixel with alpha / 255 > 0.1. Its bounding box
    (inclusive on both ends) is cropped out and centred on a square canvas
    whose side is ``ratio`` times the longer crop side. The padding is white
    with alpha 0. The canvas is then resized to ``size``.

    Args:
        rgba_img: PIL.Image in RGBA mode.
        size: target (width, height) passed to ``PIL.Image.Image.resize``.
        ratio: canvas side relative to the longer side of the foreground
            crop. Values below 1 are treated as 1 so the crop always fits.

    Returns:
        New PIL.Image in RGBA mode of the given ``size``.

    Raises:
        ValueError: if the image is not 4-channel, or has no foreground pixel.
    """
    image = np.asarray(rgba_img)
    if image.ndim != 3 or image.shape[2] != 4:
        raise ValueError(f"preprocess expects an RGBA image, got array shape {image.shape}")
    rgb, alpha = image[:, :, :3], image[:, :, 3]

    # crop: bounding box of foreground pixels
    rows, cols = np.nonzero(alpha / 255.0 > 0.1)
    if rows.size == 0:
        raise ValueError("preprocess: no foreground pixels (alpha > 0.1) to crop")
    # max() is inclusive while slice ends are exclusive, hence the +1.
    top, bottom = rows.min(), rows.max() + 1
    left, right = cols.min(), cols.max() + 1
    rgb = rgb[top:bottom, left:right]
    alpha = alpha[top:bottom, left:right]

    # padding: centre the crop on a white, fully transparent square
    h, w = rgb.shape[:2]
    side = max(int(max(h, w) * ratio), h, w)
    start_h, start_w = (side - h) // 2, (side - w) // 2
    new_rgb = np.full((side, side, 3), 255, dtype=np.uint8)
    new_alpha = np.zeros((side, side), dtype=np.uint8)
    new_rgb[start_h:start_h + h, start_w:start_w + w] = rgb
    new_alpha[start_h:start_h + h, start_w:start_w + w] = alpha

    rgba_array = np.concatenate((new_rgb, new_alpha[:, :, None]), axis=-1)
    # A uint8 (H, W, 4) array is interpreted by Pillow as RGBA.
    rgba_image = Image.fromarray(rgba_array)
    return rgba_image.resize(size)
if __name__ == "__main__":
    import argparse

    def get_args():
        """Parse CLI options: input image path, output path, and --force."""
        cli = argparse.ArgumentParser()
        for option in ("--rgb_path", "--output_rgba_path"):
            cli.add_argument(option, type=str, required=True)
        cli.add_argument("--force", default=False, action="store_true")
        return cli.parse_args()

    opts = get_args()
    source_img = Image.open(opts.rgb_path)
    remover = Removebg()
    remover(source_img, opts.force).save(opts.output_rgba_path)