File size: 2,144 Bytes
884e760
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gradio as gr
from PIL import Image
import requests
from io import BytesIO
import torch
from torchvision import transforms
from diffusers import AutoencoderKL, LCMScheduler
from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from controlnet import ControlNetModel

# Define helper functions
def download_image(url, timeout=30):
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to download.
        timeout: Seconds ``requests`` may wait on the server before giving up.
            Without a timeout the call can block indefinitely.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. This
            surfaces the real problem instead of letting PIL fail on an error
            page body.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

def load_model(device="cuda"):
    """Load the BRIA-2.3 SDXL ControlNet inpainting pipeline in float16.

    Args:
        device: Device the assembled pipeline is moved to (default ``"cuda"``).

    Returns:
        The ``StableDiffusionXLControlNetPipeline`` placed on ``device``.
    """
    # Call ``from_pretrained`` on the class. The previous
    # ``ControlNetModel().from_pretrained(...)`` first constructed a throwaway,
    # default-initialised model only to reach the loader.
    # NOTE(review): assumes the local ControlNetModel inherits diffusers'
    # classmethod ``from_pretrained`` -- confirm in controlnet.py.
    controlnet = ControlNetModel.from_pretrained(
        "briaai/DEV-ControlNetInpaintingFast", torch_dtype=torch.float16
    )
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        "briaai/BRIA-2.3",
        # Defensive cast in case the custom ControlNet class ignores torch_dtype.
        controlnet=controlnet.to(dtype=torch.float16),
        torch_dtype=torch.float16,
        vae=vae,
    )
    # NOTE(review): LCMScheduler is imported at the top of this file but never
    # used. The "Fast" ControlNet may be meant to run with an LCM scheduler
    # (and a matching LoRA) at few steps -- confirm against the model card.
    pipe.to(device)
    return pipe

# Built once at import time: load_model() pulls the pretrained weights and moves
# the pipeline to CUDA, so importing this module is slow and needs that device.
# Every call to ``inpaint`` reuses this single shared pipeline.
pipe = load_model()

# Define the inpainting function
def _to_model_inputs(image, mask, size, device):
    """Resize the uploads to a ``size`` x ``size`` square and convert them to tensors.

    Returns:
        ``(image, image_tensor, mask_tensor)``:
        - ``image`` is the resized RGB PIL image.
        - ``image_tensor`` is a ``(1, 3, size, size)`` float tensor in [0, 1].
        - ``mask_tensor`` is a ``(1, 1, size, size)`` tensor of 0/1 values;
          pixels brighter than mid-grey become 1.
    """
    image = image.resize((size, size)).convert("RGB")
    mask = mask.resize((size, size)).convert("L")
    to_tensor = transforms.ToTensor()  # PIL -> (C, H, W) float in [0, 1]
    image_tensor = to_tensor(image).unsqueeze(0).to(device)
    mask_tensor = (to_tensor(mask).unsqueeze(0) > 0.5).float().to(device)
    return image, image_tensor, mask_tensor


def _build_control_image(vae, image_tensor, mask_tensor):
    """Build the ControlNet conditioning from VAE latents of the masked image plus the mask.

    The result is the latent channels with the mask appended as one extra
    channel, at latent resolution.

    NOTE(review): mirrors the usage example on the
    briaai/DEV-ControlNetInpaintingFast model card: grey out the hole, scale
    to [-1, 1], encode, then append the downsampled mask. Re-check against the
    card if the checkpoint changes.
    """
    # Replace the region to inpaint with mid-grey so only its context is encoded.
    masked = image_tensor.masked_fill(mask_tensor > 0.5, 0.5)
    masked = masked * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    latents = vae.encode(masked.to(vae.dtype)).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
    mask_small = torch.nn.functional.interpolate(
        mask_tensor, size=tuple(latents.shape[-2:]), mode="nearest"
    )
    return torch.cat([latents, mask_small.to(latents.dtype)], dim=1)


def inpaint(image, mask, *, prompt="A park bench", num_inference_steps=50, size=1024):
    """Inpaint the masked region of ``image`` using the module-level ``pipe``.

    Args:
        image: Original PIL image, or ``None`` if nothing was uploaded.
        mask: PIL mask image. Pixels brighter than mid-grey mark the region to
            inpaint.
        prompt: Text prompt guiding the fill.
        num_inference_steps: Number of denoising steps.
        size: Side length both inputs are resized to. The output is
            ``size`` x ``size``.

    Returns:
        The inpainted PIL image.

    Raises:
        gr.Error: If either input is missing. Gradio shows the message to the
            user.
    """
    if image is None or mask is None:
        # Gradio passes None for an empty upload. Fail with a readable message
        # instead of an AttributeError on ``None.resize``.
        raise gr.Error("Please upload both an original image and a mask image.")

    image, image_tensor, mask_tensor = _to_model_inputs(image, mask, size, pipe.device)

    with torch.no_grad():
        control_image = _build_control_image(pipe.vae, image_tensor, mask_tensor)
        # The previous call omitted ``image`` (the ControlNet conditioning
        # input) and passed a [0, 1] tensor as ``init_image``.
        result = pipe(
            prompt=prompt,
            image=control_image,
            init_image=image,
            mask_image=mask_tensor.squeeze(0),  # (1, H, W), as in the model-card example
            height=size,
            width=size,
            num_inference_steps=num_inference_steps,
        ).images[0]

    # ``.images`` already holds PIL images (default output_type="pil"). The old
    # ``ToPILImage()(result.squeeze(0))`` raised on a PIL input.
    return result

# Define the interface
# ``gr.inputs`` / ``gr.outputs`` were deprecated in Gradio 3 and removed in
# Gradio 4; ``gr.Image`` serves as both an input and an output component.
interface = gr.Interface(
    fn=inpaint,
    inputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Mask Image"),
    ],
    outputs=gr.Image(type="pil", label="Inpainted Image"),
    title="Stable Diffusion XL ControlNet Inpainting",
    description="Upload an image and its corresponding mask to inpaint the specified area.",
)

if __name__ == "__main__":
    interface.launch()