import functools

import numpy as np
import torch
import torch.nn.functional as F
import gradio as gr
from ormbg import ORMBG
from PIL import Image

# Path of the ORMBG weights file, resolved relative to the launch directory.
MODEL_PATH = "ormbg.pth"
# Spatial size [height, width] images are resized to before entering the network.
MODEL_INPUT_SIZE = [1024, 1024]


def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    """Convert an image array into a batched network input tensor.

    Args:
        im: Image array of shape (H, W) or (H, W, C); values are presumably
            in 0..255 (TODO confirm for non-uint8 inputs).
        model_input_size: Target spatial size ``[height, width]``.

    Returns:
        A float tensor of shape (1, C, height, width) scaled by 1/255.
        A 2-D input yields C == 1.
    """
    if len(im.shape) < 3:
        # Give grayscale input an explicit channel axis so permute() works.
        # NOTE(review): this yields a 1-channel tensor; confirm ORMBG accepts it.
        im = im[:, :, np.newaxis]
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # The uint8 cast truncates the interpolated values back to integers before
    # rescaling; kept as-is so the network sees exactly the same input as before.
    im_tensor = F.interpolate(
        torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
    ).type(torch.uint8)
    return torch.divide(im_tensor, 255.0)


def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
    """Resize a predicted mask to ``im_size`` and convert it to a uint8 array.

    Args:
        result: Mask tensor of shape (N, C, H, W); the batch axis is squeezed
            away afterwards, so N is presumably 1 (TODO confirm).
        im_size: Target spatial size ``(height, width)``.

    Returns:
        A uint8 array in 0..255 after min-max normalization, with singleton
        axes squeezed (a single-channel mask comes back as (height, width)).
    """
    result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # A constant map cannot be min-max normalized (0 / 0 is NaN and the
        # NaN -> uint8 cast is undefined); return an all-zero mask instead.
        result = torch.zeros_like(result)
    im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
    return np.squeeze(im_array)


@functools.lru_cache(maxsize=1)
def _load_model(model_path: str = MODEL_PATH):
    """Build ORMBG, load its weights from ``model_path`` once, and cache it.

    Returns:
        ``(net, device)`` where ``net`` is in eval mode on ``device``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ORMBG()
    # map_location puts the loaded tensors on the chosen device, which covers
    # both the CUDA and the CPU-only case with a single call.
    net.load_state_dict(torch.load(model_path, map_location=device))
    net = net.to(device)
    net.eval()
    return net, device


def inference(orig_image):
    """Remove the background from ``orig_image``.

    Args:
        orig_image: Image array from the Gradio ``"image"`` input, shape
            (H, W) or (H, W, C); ``None`` when no image was submitted.

    Returns:
        An RGBA ``PIL.Image`` of the same size as the input: the original
        pixels pasted through the predicted mask onto a transparent canvas.

    Raises:
        gr.Error: if no image was submitted.
    """
    if orig_image is None:
        raise gr.Error("Please upload an image.")

    # Loaded once per process instead of on every request.
    net, device = _load_model()

    orig_im_size = orig_image.shape[0:2]
    image = preprocess_image(orig_image, MODEL_INPUT_SIZE).to(device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        result = net(image)

    # post process
    result_image = postprocess_image(result[0][0], orig_im_size)

    # Composite onto a fully transparent canvas, using the mask as paste mask.
    pil_im = Image.fromarray(result_image)
    no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    # Image.paste() takes a PIL image (or a colour), not a NumPy array.
    no_bg_image.paste(Image.fromarray(orig_image), mask=pil_im)
    return no_bg_image


# NOTE(review): these two components are created outside any gr.Blocks context
# and appear never to be rendered (gr.Interface below builds its own page from
# `title`/`description`). The HTML text also reads "that using"; confirm intent.
gr.Markdown("## Open Remove Background Model (ormbg)")
gr.HTML(
    """
This is a demo for Open Remove Background Model (ormbg) that using Open Remove Background Model (ormbg) model as backbone.
"""
)

title = "Background Removal"
description = r""" This model is a fully open-source background remover optimized for images with humans. It is based on Highly Accurate Dichotomous Image Segmentation research. You can find more about the model here. """

examples = [
    ["./input.png"],
]

demo = gr.Interface(
    fn=inference,
    inputs="image",
    outputs="image",
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch(share=False)