yonishafir commited on
Commit
884e760
·
verified ·
1 Parent(s): d932235

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import requests
4
+ from io import BytesIO
5
+ import torch
6
+ from torchvision import transforms
7
+ from diffusers import AutoencoderKL, LCMScheduler
8
+ from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
9
+ from controlnet import ControlNetModel
10
+
11
+ # Define helper functions
12
+ def download_image(url):
13
+ response = requests.get(url)
14
+ return Image.open(BytesIO(response.content)).convert("RGB")
15
+
16
+ def load_model():
17
+ # Load model components
18
+ controlnet = ControlNetModel().from_pretrained("briaai/DEV-ControlNetInpaintingFast", torch_dtype=torch.float16)
19
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
20
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained("briaai/BRIA-2.3", controlnet=controlnet.to(dtype=torch.float16), torch_dtype=torch.float16, vae=vae)
21
+ pipe.to('cuda')
22
+ return pipe
23
+
24
+ pipe = load_model()
25
+
26
+ # Define the inpainting function
27
+ def inpaint(image, mask):
28
+ # Process image and mask
29
+ image = image.resize((1024, 1024)).convert("RGB")
30
+ mask = mask.resize((1024, 1024)).convert("L")
31
+
32
+ # Transform to tensor
33
+ image_transform = transforms.ToTensor()
34
+ image_tensor = image_transform(image).unsqueeze(0).to('cuda')
35
+ mask_tensor = image_transform(mask).unsqueeze(0).to('cuda')
36
+ mask_tensor = (mask_tensor > 0.5).float() # binarize mask
37
+
38
+ # Generate image
39
+ with torch.no_grad():
40
+ result = pipe(prompt="A park bench", init_image=image_tensor, mask_image=mask_tensor, num_inference_steps=50).images[0]
41
+
42
+ return transforms.ToPILImage()(result.squeeze(0))
43
+
44
+ # Define the interface
45
+ interface = gr.Interface(fn=inpaint,
46
+ inputs=[gr.inputs.Image(type="pil", label="Original Image"), gr.inputs.Image(type="pil", label="Mask Image")],
47
+ outputs=gr.outputs.Image(type="pil", label="Inpainted Image"),
48
+ title="Stable Diffusion XL ControlNet Inpainting",
49
+ description="Upload an image and its corresponding mask to inpaint the specified area.")
50
+
51
+ if __name__ == "__main__":
52
+ interface.launch()