|
import spaces |
|
import gradio as gr |
|
import torch |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
from diffusers import StableDiffusionInpaintPipeline |
|
from model.clip_away import CLIPAway |
|
import cv2 |
|
import numpy as np |
|
import argparse |
|
|
|
|
|
# Inference settings (checkpoint paths, AlphaCLIP id, ...) read by the model set-up below.
config = OmegaConf.load("config/inference_config.yaml")

# Stable Diffusion v1.5 inpainting backbone, loaded in float32; this demo runs
# everything on CPU (see device="cpu" passed to CLIPAway below).
_SD_INPAINT_REPO = "botp/stable-diffusion-v1-5-inpainting"
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    _SD_INPAINT_REPO,
    torch_dtype=torch.float32,
)

# CLIPAway wraps the inpainting pipeline. Every checkpoint/model location comes
# from the YAML config. num_tokens=4 is presumably the number of image-prompt
# tokens used by the adapter -- confirm against model.clip_away.CLIPAway.
_clipaway_kwargs = dict(
    sd_pipe=sd_pipeline,
    image_encoder_path=config.image_encoder_path,
    ip_ckpt=config.ip_adapter_ckpt_path,
    alpha_clip_path=config.alpha_clip_ckpt_pth,
    config=config,
    alpha_clip_id=config.alpha_clip_id,
    device="cpu",
    num_tokens=4,
)
clipaway = CLIPAway(**_clipaway_kwargs)
|
|
|
def dilate_mask(mask, kernel_size=5, iterations=5):
    """Normalise a mask to a 512x512 grayscale image and grow its white region.

    Args:
        mask: PIL image whose non-black pixels mark the area to remove.
        kernel_size: Side length of the square, all-ones structuring element.
        iterations: How many times the dilation is applied.

    Returns:
        A single-channel ("L") PIL image of size 512x512 with the mask dilated.
    """
    # NEAREST resampling introduces no new pixel values, so a binary mask stays binary.
    gray = mask.convert("L").resize((512, 512), Image.NEAREST)
    structuring_element = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    grown = cv2.dilate(np.array(gray), structuring_element, iterations=iterations)
    return Image.fromarray(grown)
|
|
|
def remove_obj(image, seed):
    """Erase the sketched object from the uploaded image using CLIPAway.

    Args:
        image: Value of the ``gr.ImageMask`` input: a dict whose
            ``"background"`` entry is the uploaded image as an array and whose
            ``"layers"`` entry is a list of stroke layers; their alpha channel
            (index 3) marks where the user painted.
        seed: Seed for the initial diffusion latents. It is converted with
            ``int()`` because ``gr.Number`` may deliver a float.

    Returns:
        The first image returned by ``clipaway.generate``.

    Raises:
        gr.Error: If no image was uploaded, the seed field is empty, or nothing
            was sketched. Gradio shows the message to the user instead of a
            bare traceback.
    """
    if image is None or image.get("background") is None:
        raise gr.Error("Please upload an image first.")
    if seed is None:
        raise gr.Error("Please provide a seed.")

    # The removal mask is the union of every stroke layer: a pixel is masked if
    # any layer painted it (non-zero alpha).
    layers = [layer for layer in (image.get("layers") or []) if layer is not None]
    if not layers:
        raise gr.Error("Please sketch over the object you want to remove.")
    painted = np.any([layer[:, :, 3] != 0 for layer in layers], axis=0)
    if not painted.any():
        raise gr.Error("Please sketch over the object you want to remove.")

    # White (255) marks pixels to remove, black (0) pixels to keep.
    uploaded_mask = Image.fromarray(np.where(painted, 255, 0).astype(np.uint8))
    mask = dilate_mask(uploaded_mask)

    # The editor may hand back an RGBA background (its stroke layers are RGBA);
    # drop the alpha so the model receives a plain RGB image.
    background = Image.fromarray(image["background"]).convert("RGB")

    seed = int(seed)
    # Fixed 1x4x64x64 initial latents, presumably matching the 512x512 working
    # resolution used by dilate_mask -- TODO confirm against CLIPAway.generate.
    latents = torch.randn(
        (1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)
    ).to("cpu")

    final_image = clipaway.generate(
        prompt=[""],
        scale=1,
        seed=seed,
        pil_image=[background],
        alpha=[mask],
        strength=1,
        latents=latents,
    )[0]
    return final_image
|
|
|
|
|
# Sample inputs as [image path, mask path, seed] rows.
# NOTE(review): not referenced by the UI in this file, and the mask path has no
# matching input on the sketch-based ImageMask component.
examples = [
    [f"gradio_examples/images/{idx}.jpg", f"gradio_examples/masks/{idx}.png", example_seed]
    for idx, example_seed in ((1, 42), (2, 42), (3, 464))
]
|
|
|
# Build the single-page UI: title, links, usage notes, then input/output columns.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align:center'>CLIPAway: Harmonizing Focused Embeddings for Removing Objects via Diffusion Models</h1>")

    # Links to the paper, project page and code, centred under the title.
    gr.Markdown("""
    <div style='display:flex; justify-content:center; align-items:center;'>
    <a href='https://arxiv.org/abs/2406.09368' style="margin-right:10px;">Paper</a> |
    <a href='https://yigitekin.github.io/CLIPAway/' style="margin:10px;">Project Website</a> |
    <a href='https://github.com/YigitEkin/CLIPAway' style="margin-left:10px;">GitHub</a>
    </div>
    """)

    # Usage notes. They must describe the sketch-based ImageMask input below:
    # this page has no mask-upload component, so the object is marked with the brush.
    gr.Markdown("""
    This application allows you to remove objects from images using the CLIPAway method with diffusion models.

    To use this tool:

    1. Upload an image. (NOTE: We expect a 512x512 image; if you upload a different size, it will be resized to 512x512, which can affect the results.)
    2. Sketch over the object you want to remove with the brush tool. Everything you paint becomes the removal mask.
    3. Set the seed for reproducibility (default is 42).
    4. Click 'Remove Object' to process the image.
    5. The result will be displayed on the right side.
    """)

    with gr.Row():
        # Left column: image + sketched mask, seed, and the run button.
        with gr.Column():
            image_input = gr.ImageMask(label="Upload Image and Sketch Mask", height=700, layers=False)
            seed_input = gr.Number(value=42, label="Seed")
            process_button = gr.Button("Remove Object")

        # Right column: the inpainted result.
        with gr.Column():
            result_image = gr.Image(label="Result")

    process_button.click(
        fn=remove_obj,
        inputs=[image_input, seed_input],
        outputs=result_image,
    )

demo.launch(share=True)
|
|