File size: 5,198 Bytes
19b3da3 cd51d32 19b3da3 35575bb a3d6c18 35575bb 19b3da3 35575bb 19b3da3 42ef134 a3d6c18 35575bb f1235a4 19b3da3 a3d6c18 10230ea a3d6c18 10230ea a3d6c18 f1235a4 a3d6c18 10230ea 42ef134 10230ea 35575bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import io
from pathlib import Path
from typing import Union
import cv2
import huggingface_hub
import numpy as np
import onnxruntime as rt
import torch
import torch.nn.functional as F
from briarmbg import BriaRMBG # pyright: ignore
from PIL import Image
from rembg import remove
from torchvision.transforms.functional import normalize
import internals.util.image as ImageUtil
from carvekit.api.high import HiInterface
from internals.data.task import ModelType
from internals.util.commons import download_image, read_url
class RemoveBackground:
    """Background remover backed by ``rembg.remove``."""

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Remove the background from ``image`` using rembg.

        Args:
            image: Either a string, which is fetched with ``read_url`` and
                decoded with PIL, or an already-loaded PIL image.

        Returns:
            The result of ``rembg.remove`` for the PIL input (presumably a PIL
            image with an alpha channel — TODO confirm for the rembg version in use).
        """
        if isinstance(image, str):
            # read_url presumably returns the raw encoded bytes; wrap them so PIL
            # can decode from memory without touching disk.
            image = Image.open(io.BytesIO(read_url(image)))
        return remove(image)
class RemoveBackgroundV2:
    """Background remover that picks a model based on the image's content type.

    Anime/comic images are segmented with the ``skytnt/anime-seg`` ISNet ONNX
    model; everything else goes through carvekit's ``HiInterface``.
    """

    def __init__(self):
        model_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.onnx")
        # Providers are listed in priority order: CUDA when available, else CPU.
        self.anime_rembg = rt.InferenceSession(
            model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device="cuda" if torch.cuda.is_available() else "cpu",
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    def remove(
        self, image: Union[str, Image.Image], model_type: ModelType = ModelType.REAL
    ) -> Image.Image:
        """Remove the background from ``image``.

        Args:
            image: A string passed to ``download_image``, or a PIL image.
            model_type: ``ModelType.ANIME`` / ``ModelType.COMIC`` select the
                anime segmentation model; any other value uses carvekit.

        Returns:
            Anime/comic: an RGBA image whose RGB is the input blended onto
            white by the predicted mask, with the mask as alpha.
            Otherwise: the first result of ``HiInterface`` (presumably a PIL
            image — TODO confirm).
        """
        if isinstance(image, str):
            image = download_image(image)

        if model_type in (ModelType.ANIME, ModelType.COMIC):
            print("Using Anime Background remover")
            # __get_mask builds a 3-channel model input, so RGBA / L / P images
            # must be normalised to RGB first or the padding assignment fails.
            _, img = self.__rmbg_fn(np.array(image.convert("RGB")))
            return Image.fromarray(img)

        print("Using Real Background remover")
        # NOTE(review): fixed scratch path — concurrent calls in one process or
        # across processes would overwrite each other's input file.
        img_path = Path.home() / ".cache" / "rm_bg.png"
        # ~/.cache is not guaranteed to exist; image.save would raise otherwise.
        img_path.parent.mkdir(parents=True, exist_ok=True)
        w, h = image.size
        if max(w, h) > 1536:
            image = ImageUtil.resize_image(image, dimension=1024)
        image.save(img_path)
        images_without_background = self.interface([img_path])
        return images_without_background[0]

    def __get_mask(self, img, s=1024):
        """Predict a soft foreground mask for ``img`` with the ISNet ONNX model.

        The image is scaled so its longer side equals ``s``, centred on an
        ``s x s`` zero-padded canvas and run through the model. The prediction
        is then cropped back to the scaled region and resized to the input
        resolution.

        Args:
            img: ``(H, W, 3)`` array with 0-255 values.
            s: Side length of the square model input.

        Returns:
            ``(H, W, 1)`` float array, used by ``__rmbg_fn`` as an alpha in
            [0, 1] (range presumably guaranteed by the model — TODO confirm).
        """
        img = (img / 255).astype(np.float32)
        h0, w0 = img.shape[:-1]
        # Fit the longer side to s, preserving aspect ratio. Clamp the shorter
        # side to >= 1 px: extreme aspect ratios would otherwise produce a
        # zero-sized target, which cv2.resize rejects.
        if h0 > w0:
            h, w = s, max(1, int(s * w0 / h0))
        else:
            h, w = max(1, int(s * h0 / w0)), s
        ph, pw = s - h, s - w

        img_input = np.zeros([s, s, 3], dtype=np.float32)
        img_input[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w] = cv2.resize(
            img, (w, h)
        )
        # HWC -> NCHW with a batch of one, as the ONNX "img" input expects.
        img_input = np.transpose(img_input, (2, 0, 1))
        img_input = img_input[np.newaxis, :]

        mask = self.anime_rembg.run(None, {"img": img_input})[0][0]
        mask = np.transpose(mask, (1, 2, 0))
        mask = mask[ph // 2 : ph // 2 + h, pw // 2 : pw // 2 + w]
        # cv2.resize drops a singleton channel axis, so restore it.
        mask = cv2.resize(mask, (w0, h0))[:, :, np.newaxis]
        return mask

    def __rmbg_fn(self, img):
        """Composite ``img`` onto white using the predicted mask.

        Args:
            img: ``(H, W, 3)`` array with 0-255 values.

        Returns:
            ``(mask, img)`` where ``mask`` is the uint8 0-255 mask repeated to
            3 channels, and ``img`` is an ``(H, W, 4)`` uint8 array: the RGB
            blended onto white, with the mask appended as alpha.
        """
        mask = self.__get_mask(img)
        img = (mask * img + 255 * (1 - mask)).astype(np.uint8)
        mask = (mask * 255).astype(np.uint8)
        img = np.concatenate([img, mask], axis=2, dtype=np.uint8)
        mask = mask.repeat(3, axis=2)
        return mask, img
class RemoveBackgroundV3:
    """Background remover backed by BRIA's ``briaai/RMBG-1.4`` model."""

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
        net.to(self.device)
        # Inference only: put layers whose behaviour differs between training
        # and evaluation (e.g. dropout / batch norm, if present) in eval mode.
        net.eval()
        self.net = net

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Remove the background from ``image``.

        Args:
            image: A string passed to ``download_image`` (loaded as RGBA), or a
                PIL image.

        Returns:
            An RGBA image the size of the input. The input is pasted onto a
            transparent canvas, with the model's min-max-normalised prediction
            as the paste mask.
        """
        if isinstance(image, str):
            image = download_image(image, mode="RGBA")
        orig_image = image
        w, h = orig_image.size

        # Pre-process: RGB at the fixed model size, scaled to [0, 1], then
        # shifted by the 0.5 mean (std 1.0), as NCHW with a batch of one.
        im_np = np.array(self.__resize_image(orig_image))
        im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        im_tensor = im_tensor.to(self.device)

        # No autograd graph is needed for inference; skipping it saves memory.
        with torch.no_grad():
            result = self.net(im_tensor)
            # result[0][0] is used as a 4-D (1, 1, H', W') map: bilinear
            # interpolate needs N, C, H, W, and squeezing must leave 2-D.
            mask = torch.squeeze(
                F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
            )
            ma = torch.max(mask)
            mi = torch.min(mask)
            value_range = ma - mi
            if value_range > 0:
                mask = (mask - mi) / value_range
            else:
                # Constant prediction: (mask - mi) is all zeros, so avoid the
                # 0/0 NaNs and produce an all-zero (fully transparent) mask.
                mask = torch.zeros_like(mask)

        im_array = (mask * 255).cpu().numpy().astype(np.uint8)
        pil_im = Image.fromarray(np.squeeze(im_array))

        # Paste the original onto a transparent canvas using the mask as alpha.
        new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
        new_im.paste(orig_image, mask=pil_im)
        return new_im

    def __resize_image(self, image, size=(1024, 1024)):
        """Return ``image`` as RGB, bilinearly resized to ``size`` (width, height).

        The aspect ratio is not preserved. The default ``(1024, 1024)`` is the
        model input size used by ``remove``.
        """
        image = image.convert("RGB")
        return image.resize(size, Image.BILINEAR)
|