File size: 3,898 Bytes
44e4a4a
92e0882
 
 
 
 
 
 
 
 
 
 
b8231cb
92e0882
02ab530
92e0882
 
 
 
 
 
 
 
1496374
92e0882
 
 
 
c80b748
92e0882
 
 
 
2c45810
45c2198
2c45810
 
 
 
c80b748
92e0882
2c45810
92e0882
 
2c45810
 
 
 
92e0882
 
 
 
 
8a6f0b6
 
 
92e0882
 
5fb2341
92e0882
 
 
c80b748
 
 
92e0882
 
 
 
 
c80b748
8a6f0b6
 
 
 
 
92e0882
 
 
 
 
5fb2341
92e0882
 
 
 
 
 
 
2c45810
92e0882
 
 
46be095
92e0882
45c2198
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import spaces
import gradio as gr
import torch
from omegaconf import OmegaConf
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline
from model.clip_away import CLIPAway
import cv2
import numpy as np
import argparse

# Inference settings (checkpoint paths, Alpha-CLIP id, ...) are read from YAML.
config = OmegaConf.load("config/inference_config.yaml")

# Stable Diffusion inpainting backbone, loaded in full precision.
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "botp/stable-diffusion-v1-5-inpainting",
    torch_dtype=torch.float32,
)

# CLIPAway wraps the inpainting pipeline; everything is placed on the CPU.
clipaway = CLIPAway(
    # Backbone and global configuration.
    sd_pipe=sd_pipeline,
    config=config,
    # Checkpoints and encoder locations taken from the config file.
    image_encoder_path=config.image_encoder_path,
    ip_ckpt=config.ip_adapter_ckpt_path,
    alpha_clip_path=config.alpha_clip_ckpt_pth,
    alpha_clip_id=config.alpha_clip_id,
    # Runtime settings.
    device="cpu",
    num_tokens=4,
)

def dilate_mask(mask, kernel_size=5, iterations=5):
    """Return a 512x512 grayscale copy of *mask* with its bright regions grown.

    Args:
        mask: PIL image holding the mask.
        kernel_size: Side length of the square all-ones dilation kernel.
        iterations: Number of times the dilation is applied.

    Returns:
        A PIL image built from the dilated 8-bit array.
    """
    # Single channel first, then nearest-neighbour resize so no intermediate
    # grey levels are introduced by interpolation.
    grayscale = mask.convert("L")
    resized = grayscale.resize((512, 512), Image.NEAREST)

    structuring_element = np.ones((kernel_size, kernel_size), np.uint8)
    dilated = cv2.dilate(
        np.array(resized), structuring_element, iterations=iterations
    )
    return Image.fromarray(dilated)

def remove_obj(image, seed):
    """Remove the masked object from the uploaded image with CLIPAway.

    Args:
        image: Value produced by the mask editor component: a mapping with a
            "background" array (the uploaded picture) and a "layers" list
            whose first entry carries the drawn strokes in its fourth
            (alpha) channel.
        seed: Seed for the initial latents and the generation call; coerced
            to ``int``.

    Returns:
        The first image returned by ``clipaway.generate``.

    Raises:
        ValueError: If no image has been provided.
    """
    if image is None or image.get("background") is None:
        raise ValueError("Please upload an image before removing an object.")

    # Any pixel the user painted has non-zero alpha in the sketch layer;
    # turn that into a binary 0/255 mask.
    alpha_channel = image["layers"][0][:, :, 3]
    mask = np.where(alpha_channel == 0, 0, 255).astype(np.uint8)
    uploaded_mask = Image.fromarray(mask)

    # Fixed: the original read the undefined name `img` here, which raised
    # NameError on every call; the parameter is `image`.
    background = Image.fromarray(image["background"])

    mask = dilate_mask(uploaded_mask)
    seed = int(seed)

    # Seeded starting noise so the same seed reproduces the same result.
    generator = torch.Generator().manual_seed(seed)
    latents = torch.randn((1, 4, 64, 64), generator=generator).to("cpu")

    final_image = clipaway.generate(
        prompt=[""], scale=1, seed=seed,
        pil_image=[background],
        alpha=[mask],
        strength=1,
        latents=latents
    )[0]
    return final_image

# Example data: each row is [image path, mask path, number].
# The number is presumably the seed to use for that example -- TODO confirm.
# NOTE(review): `examples` is not referenced anywhere in the visible part of
# this file (no gr.Examples is created) -- confirm whether it is still needed.
examples = [
    ["gradio_examples/images/1.jpg", "gradio_examples/masks/1.png", 42],
    ["gradio_examples/images/2.jpg", "gradio_examples/masks/2.png", 42],
    ["gradio_examples/images/3.jpg", "gradio_examples/masks/3.png", 464],
]

# Build the Gradio page: header, usage notes, then a two-column layout with
# the inputs on the left and the generated result on the right.
with gr.Blocks() as demo:
    # Page title.
    gr.Markdown("<h1 style='text-align:center'>CLIPAway: Harmonizing Focused Embeddings for Removing Objects via Diffusion Models</h1>")
    # Links to the paper, project page and source repository.
    gr.Markdown("""
        <div style='display:flex; justify-content:center; align-items:center;'>
            <a href='https://arxiv.org/abs/2406.09368' style="margin-right:10px;">Paper</a> |
            <a href='https://yigitekin.github.io/CLIPAway/' style="margin:10px;">Project Website</a> |
            <a href='https://github.com/YigitEkin/CLIPAway' style="margin-left:10px;">GitHub</a>
        </div>
    """)
    # Usage instructions shown to the user.
    # NOTE(review): this text tells users to upload a pre-defined mask and says
    # sketching is unavailable here, but the input below is a gr.ImageMask
    # labelled "Upload Image and Sketch Mask" -- confirm the text is current.
    gr.Markdown("""
            This application allows you to remove objects from images using the CLIPAway method with diffusion models.
            To use this tool:
            1. Upload an image. (NOTE: We expect a 512x512 image, if you upload a different size, it will be resized to 512x512 which can affect the results.)
            2. Upload a pre-defined mask if you have one. (If you don't have a mask, and want to sketch one, 
            we have provided a gradio demo in our github repository. <br/> Unfortunately, we cannot provide it here due to the compatibility issues with zerogpu.)
            3. Set the seed for reproducibility (default is 42).
            4. Click 'Remove Object' to process the image.
            5. The result will be displayed on the right side.
            Note: The mask should be a binary image where the object to be removed is white and the background is black.
    """)
    
    with gr.Row():
        # Left column: image + mask editor, seed, and the trigger button.
        with gr.Column():
            image_input = gr.ImageMask(label="Upload Image and Sketch Mask", height=700, layers=False)
            seed_input = gr.Number(value=42, label="Seed")
            process_button = gr.Button("Remove Object")
        # Right column: output of remove_obj.
        with gr.Column():
            result_image = gr.Image(label="Result")
    
    # Clicking the button passes the editor value and the seed to remove_obj
    # and shows its return value in result_image.
    process_button.click(
        fn=remove_obj,
        inputs=[image_input, seed_input],
        outputs=result_image
    )



# Start the app at import time; share=True asks Gradio for a public link.
demo.launch(share=True)