# Hugging Face Spaces status: Running
import os
import urllib.request

import cv2
import gradio as gr
import imageio
import numpy as np
import onnxruntime
import paddlehub as hub
from PIL import Image
# Download and set up models.
LAMA_MODEL_URL = "https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx"
LAMA_MODEL_PATH = "lama_fp32.onnx"

# Fetch the LaMa weights only when they are not already on disk. An
# unconditional `wget` re-downloads on every start (saving duplicates such as
# lama_fp32.onnx.1) and its exit status is ignored by os.system. Downloading to
# a temporary name and renaming means an interrupted transfer never leaves a
# truncated model at LAMA_MODEL_PATH; urlretrieve raises on HTTP errors.
if not os.path.exists(LAMA_MODEL_PATH):
    _partial_path = LAMA_MODEL_PATH + ".part"
    urllib.request.urlretrieve(LAMA_MODEL_URL, _partial_path)
    os.replace(_partial_path, LAMA_MODEL_PATH)

# Python dependencies (onnxruntime, imageio, ...) are imported at the top of
# this file, so they must be installed beforehand (e.g. via requirements.txt);
# installing them at runtime after the imports cannot take effect.

# Working directories for the uploaded image, the mask, and the result.
os.makedirs("data", exist_ok=True)
os.makedirs("dataout", exist_ok=True)

# Load LaMa ONNX model
sess_options = onnxruntime.SessionOptions()
lama_model = onnxruntime.InferenceSession(LAMA_MODEL_PATH, sess_options=sess_options)

# Load U^2-Net model for automatic masking
u2net_model = hub.Module(name='U2Net')
# --- Helper Functions ---
def prepare_image(image, target_size=(512, 512)):
    """Resize *image* and turn it into a batched float32 tensor for LaMa.

    Args:
        image: A ``PIL.Image.Image`` or a ``numpy.ndarray``.
        target_size: ``(width, height)``, the order expected by both PIL's
            ``Image.resize`` and ``cv2.resize``.

    Returns:
        The resized pixels divided by 255 as float32, channel-first, with a
        leading batch axis: ``(1, C, H, W)`` for an HxWxC image and
        ``(1, 1, H, W)`` for a 2-D (single-channel) image.

    Raises:
        ValueError: If *image* is neither a PIL image nor a NumPy array.
    """
    if isinstance(image, Image.Image):
        pixels = np.array(image.resize(target_size))
    elif isinstance(image, np.ndarray):
        pixels = cv2.resize(image, target_size)
    else:
        raise ValueError("Input image should be either PIL Image or numpy array!")

    scaled = pixels.astype(np.float32) / 255.0
    if scaled.ndim == 3:
        chw = np.transpose(scaled, (2, 0, 1))  # HWC -> CHW
    elif scaled.ndim == 2:
        chw = scaled[np.newaxis, ...]  # HW -> 1HW (single channel)
    else:
        chw = scaled
    # Prepend the batch axis expected by the ONNX session.
    return chw[np.newaxis, ...]
def generate_mask(image, method="automatic"):
    """Build the LaMa mask tensor for *image*.

    With ``method="automatic"`` the mask comes from U^2-Net segmentation. It is
    converted to grayscale, resized to 512x512 and saved to
    ``./data/data_mask.png``, both for display and so later "manual" runs can
    reuse it. Any other *method* value re-reads the mask previously saved at
    that path. No interactive mask drawing happens here.

    Args:
        image: RGB image as a NumPy array. It is converted to BGR before being
            handed to U^2-Net and is only used in automatic mode.
        method: ``"automatic"`` or ``"manual"``.

    Returns:
        The mask as produced by ``prepare_image`` for a grayscale image:
        float32 ``(1, 1, 512, 512)`` with values scaled to [0, 1].

    Raises:
        FileNotFoundError: If a non-automatic method is requested but no mask
            has been saved yet.
    """
    mask_path = "./data/data_mask.png"
    target_size = (512, 512)  # LaMa input resolution used throughout this app

    if method == "automatic":
        result = u2net_model.Segmentation(
            images=[cv2.cvtColor(image, cv2.COLOR_RGB2BGR)],
            paths=None,
            batch_size=1,
            input_size=320,  # U^2-Net input size; adjust per model requirements
            output_dir='output',
            visualization=False,
        )
        # Force a single channel, matching the manual branch, so LaMa always
        # receives a one-channel mask.
        mask = Image.fromarray(result[0]['mask']).convert("L")
        mask = mask.resize(target_size)
        mask.save(mask_path)
    else:  # "manual": reuse the previously saved mask
        if not os.path.exists(mask_path):
            raise FileNotFoundError(
                f"No saved mask found at {mask_path!r}; run with 'automatic' "
                "masking first to generate one."
            )
        mask = Image.fromarray(imageio.imread(mask_path)).convert("L")

    # prepare_image performs the resize to target_size, so no separate resize
    # is needed for the manual branch.
    return prepare_image(mask, target_size)
def inpaint_image(image, mask):
    """Run LaMa inpainting and return the result as a PIL image.

    Args:
        image: Batched channel-first float32 image tensor, as produced by
            ``prepare_image``.
        mask: Batched single-channel float32 mask tensor, as produced by
            ``generate_mask``.

    Returns:
        A ``PIL.Image.Image`` built from the first batch item of the model's
        first output.
    """
    outputs = lama_model.run(None, {'image': image, 'mask': mask})
    output = outputs[0][0].transpose(1, 2, 0)  # CHW -> HWC for PIL
    # NOTE(review): the "* 255" assumes the model emits values in [0, 1];
    # confirm against the Carve/LaMa-ONNX model card. If it already emits
    # [0, 255], this scaling must be removed.
    # Clip before casting: float->uint8 conversion does not saturate
    # out-of-range values (they typically wrap), which shows up as speckle.
    output = np.clip(output * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(output)
# --- Gradio Interface ---
def process_image(input_image, mask_option):
    """Gradio callback: inpaint *input_image* and return file paths to display.

    Args:
        input_image: RGB image as a NumPy array (the input component uses
            ``type="numpy"``), or None when nothing was uploaded.
        mask_option: Masking mode forwarded to ``generate_mask``.

    Returns:
        ``(inpainted_path, mask_path)``: the inpainted image, resized back to
        the input's dimensions, and the mask file written or read by
        ``generate_mask``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # Keep a copy of the upload on disk.
    # NOTE(review): all paths here are fixed, so concurrent users overwrite
    # each other's files.
    imageio.imwrite("./data/data.png", input_image)
    # Take the original size from the array itself rather than re-opening the
    # file just written (which also left the file handle open). NumPy shapes
    # are (height, width, ...), while PIL sizes are (width, height).
    height, width = input_image.shape[:2]

    image = prepare_image(input_image)
    mask = generate_mask(input_image, method=mask_option)
    inpainted_image = inpaint_image(image, mask)
    inpainted_image = inpainted_image.resize((width, height))
    inpainted_image.save("./dataout/data_mask.png")
    return "./dataout/data_mask.png", "./data/data_mask.png"
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Input Image", type="numpy"),
        # Default to "automatic": with no selection the Radio yields None, which
        # the mask code would route to the "manual" (reuse saved mask) branch.
        gr.Radio(
            choices=["automatic", "manual"],
            value="automatic",
            type="value",
            label="Masking Option",
        ),
    ],
    outputs=[
        # process_image returns file paths. "filepath" is the accepted gr.Image
        # type for that; "file" is rejected by current Gradio releases.
        gr.Image(type="filepath", label="Inpainted Image"),
        gr.Image(type="filepath", label="Generated Mask"),
    ],
    title="LaMa Image Inpainting",
    description="Image inpainting with LaMa and U^2-Net. Upload your image and choose automatic or manual masking.",
)
iface.launch()