# flask_rm_bg-cp / example_inference.py
# Last commit: 057eeaa "added color" by sammyview80 (verified); 1.32 kB; virus scan: clean.
import base64
import functools
import io as IO
import os

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from skimage import io

from briarmbg import BriaRMBG
from utilities import preprocess_image, postprocess_image

@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the RMBG-1.4 network once per process and reuse it across calls.

    Returns:
        A ``(net, device)`` tuple: the model in eval mode, moved to CUDA when
        available (CPU otherwise), and the ``torch.device`` it was moved to.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
    net.to(device)
    net.eval()
    return net, device


def example_inference(im_path, transprent_bg=False, color=(255, 255, 255, 255)):
    """Remove the background from an image and return the result as base64 PNG.

    Args:
        im_path: The input image, either as raw encoded bytes (e.g. the body of
            an upload) or as a filesystem path. Despite the name, earlier
            versions accepted only bytes.
        transprent_bg: If true, the background is fully transparent
            ``(0, 0, 0, 0)`` and ``color`` is ignored.
        color: RGBA tuple used to fill the background when ``transprent_bg``
            is false.

    Returns:
        The RGBA composite encoded as PNG, then base64, as a ``str``.
    """
    # Loading the network is expensive, so it is cached instead of being
    # rebuilt on every request.
    net, device = _load_model()

    # Prepare the model input. The size is passed to preprocess_image
    # (presumably the resolution the image is resized to -- TODO confirm).
    model_input_size = [1024, 1024]
    orig_im = io.imread(im_path, plugin='imageio')
    orig_im_size = orig_im.shape[0:2]
    image = preprocess_image(orig_im, model_input_size).to(device)

    # Inference only: no_grad avoids building an autograd graph for every call.
    with torch.no_grad():
        result = net(image)

    # NOTE(review): result[0][0] appears to be the first output's first batch
    # item -- depends on BriaRMBG's return structure; confirm.
    result_image = postprocess_image(result[0][0], orig_im_size)

    # The predicted matte becomes the paste mask over a solid or transparent
    # background of the original image's size.
    bg_color = (0, 0, 0, 0) if transprent_bg else color
    mask = Image.fromarray(result_image)
    no_bg_image = Image.new("RGBA", mask.size, bg_color)

    # Image.open takes paths directly; raw bytes must be wrapped in a file object.
    source = IO.BytesIO(im_path) if isinstance(im_path, (bytes, bytearray)) else im_path
    with Image.open(source) as orig_image:
        no_bg_image.paste(orig_image, mask=mask)

    # Serialize to PNG in memory and return it base64-encoded.
    buffered = IO.BytesIO()
    no_bg_image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")