nikunjkdtechnoland
init commit some more add files
e041d7d
raw
history blame
No virus
2.21 kB
import os
import cv2
import numpy as np
from loguru import logger
from torch.hub import get_dir
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest, RemoveBGModel
class RemoveBG(BasePlugin):
name = "RemoveBG"
support_gen_mask = True
support_gen_image = True
def __init__(self, model_name):
super().__init__()
self.model_name = model_name
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
os.environ["U2NET_HOME"] = model_dir
self._init_session(model_name)
def _init_session(self, model_name: str):
if model_name == RemoveBGModel.briaai_rmbg_1_4:
from iopaint.plugins.briarmbg import (
create_briarmbg_session,
briarmbg_process,
)
self.session = create_briarmbg_session()
self.remove = briarmbg_process
else:
from rembg import new_session, remove
self.session = new_session(model_name=model_name)
self.remove = remove
def switch_model(self, new_model_name):
if self.model_name == new_model_name:
return
logger.info(
f"Switching removebg model from {self.model_name} to {new_model_name}"
)
self._init_session(new_model_name)
self.model_name = new_model_name
def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
# return BGRA image
output = self.remove(bgr_np_img, session=self.session)
return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
# return BGR image, 255 means foreground, 0 means background
output = self.remove(bgr_np_img, session=self.session, only_mask=True)
return output
def check_dep(self):
try:
import rembg
except ImportError:
return (
"RemoveBG is not installed, please install it first. pip install rembg"
)