"""Gradio demo: LaMa image inpainting with manual or U^2-Net generated masks.

Pre/post-processing helpers follow https://github.com/advimman/lama
"""
import functools
import os
import subprocess
import sys
import urllib.request

LAMA_URL = "https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama.onnx"
LAMA_PATH = "lama.onnx"

# Working files shared by ``infer`` and ``predict``. All are relative to the
# current directory, matching the directories created below, so the app does
# not depend on being launched from one specific absolute location.
DATA_DIR = "./data"
OUTPUT_DIR = "./dataout"
INPUT_IMAGE_PATH = "./data/data.png"
INPUT_MASK_PATH = "./data/data_mask.png"
OUTPUT_IMAGE_PATH = "./dataout/data_mask.png"


def _download(url, dest):
    """Download ``url`` to ``dest`` via a ``.part`` temp file, then rename.

    Renaming only after a complete download avoids leaving a truncated
    ``dest`` that the existence check below would then trust.
    """
    tmp = dest + ".part"
    urllib.request.urlretrieve(url, tmp)
    os.replace(tmp, dest)


# Bootstrap: fetch the LaMa weights once (skip if already present, instead of
# re-downloading on every restart) and make sure onnxruntime is installed.
if not os.path.exists(LAMA_PATH):
    _download(LAMA_URL, LAMA_PATH)
# Best-effort, as before: if pip fails, a missing package surfaces as an
# ImportError below. ``sys.executable -m pip`` targets this interpreter.
subprocess.run([sys.executable, "-m", "pip", "install", "onnxruntime"], check=False)

# Third-party imports deliberately come after the bootstrap (as in the
# original), so nothing is imported before pip may have changed packages.
import cv2  # noqa: E402
import gradio as gr  # noqa: E402
import imageio  # noqa: E402
import numpy as np  # noqa: E402
import onnxruntime  # noqa: E402
import paddlehub as hub  # noqa: E402
import torch  # noqa: E402
from PIL import Image, ImageOps  # noqa: E402,F401

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# U^2-Net segmenter used for the "automatic (U2net)" masking option.
model = hub.Module(name='U2Net')


def get_image(image):
    """Convert a PIL image or numpy array to a float32 CHW array.

    Args:
        image: ``PIL.Image.Image`` or ``np.ndarray`` laid out as HW or HWC.

    Returns:
        ``np.ndarray`` of shape (C, H, W), dtype float32, divided by 255
        (so uint8 input ends up in [0, 1]).

    Raises:
        TypeError: if ``image`` is neither a PIL image nor a numpy array.
        ValueError: if the array is not 2-D or 3-D.
    """
    if isinstance(image, Image.Image):
        img = np.array(image)
    elif isinstance(image, np.ndarray):
        img = image.copy()
    else:
        raise TypeError("Input image should be either PIL Image or numpy array!")

    if img.ndim == 3:
        img = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    elif img.ndim == 2:
        img = img[np.newaxis, ...]  # HW -> 1HW
    if img.ndim != 3:
        raise ValueError(f"Expected a 2-D or 3-D image, got shape {img.shape}")
    return img.astype(np.float32) / 255


def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``."""
    if x % mod == 0:
        return x
    return (x // mod + 1) * mod


def scale_image(img, factor, interpolation=cv2.INTER_AREA):
    """Resize a CHW array by ``factor`` on both spatial axes, keeping CHW."""
    # cv2.resize works on HW / HWC, so convert, resize, then convert back.
    if img.shape[0] == 1:
        img = img[0]
    else:
        img = np.transpose(img, (1, 2, 0))

    img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)

    if img.ndim == 2:
        img = img[None, ...]
    else:
        img = np.transpose(img, (2, 0, 1))
    return img


def pad_img_to_modulo(img, mod):
    """Symmetric-pad the bottom/right of a CHW array so H and W divide ``mod``."""
    _, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    return np.pad(
        img,
        ((0, 0), (0, out_height - height), (0, out_width - width)),
        mode="symmetric",
    )


def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
    """Turn an image/mask pair into batched tensors for the inpainting model.

    Returns:
        ``(image, mask)`` tensors of shape (1, C, H, W) on ``device``. The
        image is float32; the mask is binarised (1 wherever the input was > 0).
    """
    out_image = get_image(image)
    out_mask = get_image(mask)

    if scale_factor is not None:
        out_image = scale_image(out_image, scale_factor)
        # Nearest-neighbour keeps the mask hard-edged.
        out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)

    if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
        out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
        out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)

    out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
    out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)
    out_mask = (out_mask > 0) * 1
    return out_image, out_mask


@functools.lru_cache(maxsize=1)
def _get_lama_session():
    """Create the LaMa ONNX session once and reuse it for every request."""
    sess_options = onnxruntime.SessionOptions()
    return onnxruntime.InferenceSession(LAMA_PATH, sess_options=sess_options)


def predict(jpg, msk):
    """Inpaint the image at ``jpg`` wherever the mask at ``msk`` is non-zero.

    Both inputs are resized to 512x512 before inference. The result is
    written to ``OUTPUT_IMAGE_PATH``.

    Args:
        jpg: path (or file object) of the input image.
        msk: path (or file object) of the mask image.

    Returns:
        ``OUTPUT_IMAGE_PATH``, the file the inpainted image was saved to.
    """
    session = _get_lama_session()
    with Image.open(jpg) as src:
        # Force 3 channels: an RGBA or greyscale file would not match the
        # model's image input.
        image = src.convert("RGB").resize((512, 512))
    with Image.open(msk) as src:
        mask = src.convert("L").resize((512, 512))
    image, mask = prepare_img_and_mask(image, mask, 'cpu')

    outputs = session.run(None, {
        'l_image_': image.numpy().astype(np.float32),
        'l_mask_': mask.numpy().astype(np.float32),
    })

    output = outputs[0][0].transpose(1, 2, 0)  # CHW -> HWC
    # NOTE(review): assumes the model emits values on a 0-255 scale (the
    # original cast straight to uint8). Clipping first stops slightly
    # out-of-range floats from wrapping around in the cast.
    output = np.clip(output, 0, 255).astype(np.uint8)
    Image.fromarray(output).save(OUTPUT_IMAGE_PATH)
    return OUTPUT_IMAGE_PATH


def infer(img, option):
    """Gradio callback: obtain a mask, run inpainting, and return file paths.

    Args:
        img: value of the sketch-enabled ``gr.Image`` input, indexed here as
            ``img["image"]`` and ``img["mask"]`` (numpy arrays).
        option: ``"automatic (U2net)"`` to segment with U^2-Net; any other
            value uses the hand-drawn mask.

    Returns:
        ``(inpainted_image_path, mask_path)``.
    """
    # NOTE(review): these paths are fixed, so concurrent requests overwrite
    # each other's files; consider per-request temp dirs if run in parallel.
    imageio.imwrite(INPUT_IMAGE_PATH, img["image"])
    if option == "automatic (U2net)":
        result = model.Segmentation(
            images=[cv2.cvtColor(img["image"], cv2.COLOR_RGB2BGR)],
            paths=None,
            batch_size=1,
            input_size=320,
            output_dir='output',
            visualization=True)
        Image.fromarray(result[0]['mask']).save(INPUT_MASK_PATH)
    else:
        imageio.imwrite(INPUT_MASK_PATH, img["mask"])
    predict(INPUT_IMAGE_PATH, INPUT_MASK_PATH)
    return OUTPUT_IMAGE_PATH, INPUT_MASK_PATH


inputs = [
    gr.Image(tool="sketch", label="Input", type="numpy"),
    gr.Radio(
        choices=["automatic (U2net)", "manual"],
        type="value",
        value="manual",
        label="Masking option",
    ),
]
outputs = [
    gr.Image(type="filepath", label="output"),
    gr.Image(type="filepath", label="Mask"),
]
title = "LaMa Image Inpainting"
description = "Gradio demo for LaMa: Resolution-robust Large Mask Inpainting with Fourier Convolutions. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below. Masks are generated by U^2net"
# A plain "..." literal cannot span lines; build the HTML footer from
# adjacent string literals instead.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2109.07161' target='_blank'>"
    "Resolution-robust Large Mask Inpainting with Fourier Convolutions</a>"
    " | <a href='https://github.com/advimman/lama' target='_blank'>Github Repo</a>"
    "</p>"
)

demo = gr.Interface(infer, inputs, outputs, title=title, description=description, article=article)

if __name__ == "__main__":
    demo.launch()