File size: 1,715 Bytes
08fc0b2
0a2d6b8
08fc0b2
cfe97ad
 
 
08fc0b2
9539987
 
 
 
 
 
 
 
 
08fc0b2
 
 
 
 
 
 
 
f0cce29
429369b
f0cce29
6b87482
 
968c286
08fc0b2
521379d
 
9fa4928
521379d
 
08fc0b2
 
 
 
 
9539987
08fc0b2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import gradio as gr
import torch
from diffusers import AutoPipelineForInpainting, AutoencoderKL
from PIL import Image, ImageOps

# Standalone VAE checkpoint, loaded in half precision so it can be handed to
# the inpainting pipeline below in place of the pipeline's own VAE.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# Inpainting pipeline (fp16 weight variant, safetensors files) built around
# the VAE above, then moved to the CUDA device.
pipeline = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipeline = pipeline.to("cuda")

def generate(image_editor):
    """Build the inpainting inputs from the sketch pad's current value.

    Args:
        image_editor: Value of the sketch pad component: a mapping with a
            ``'background'`` image and a ``'layers'`` list of drawn layers.
            The component is created with ``type='pil'``, so these are
            expected to be PIL images.

    Returns:
        A ``(image_editor, image, mask)`` tuple: the editor value passed in
        (unchanged), the background converted to RGB, and a single-channel
        mask derived from the first layer. The image and the mask are both
        shrunk in place to fit within 1024x1024.

    Raises:
        gr.Error: If no background image has been loaded yet, so the UI shows
            a readable message instead of an opaque attribute/type error.
    """
    if not image_editor or image_editor.get("background") is None:
        raise gr.Error("Please upload an image before generating.")

    image = image_editor["background"].convert("RGB")

    layers = image_editor.get("layers") or []
    if layers:
        layer = layers[0]
        # Composite the drawn layer onto an opaque white canvas, using the
        # layer itself as the paste mask, then flatten to grayscale and
        # invert: untouched (white) areas end up black, dark strokes end up
        # light.
        mask = Image.new("RGBA", layer.size, "WHITE")
        mask.paste(layer, (0, 0), layer)
        mask = ImageOps.invert(mask.convert("L"))
    else:
        # No drawing layer available: fall back to an all-black mask of the
        # background's size instead of failing on layers[0].
        mask = Image.new("L", image.size, 0)

    # thumbnail() resizes in place and never enlarges.
    image.thumbnail((1024, 1024))
    mask.thumbnail((1024, 1024))

    return image_editor, image, mask

# Declarative UI layout. Components are created in the same order as they
# should appear on the page.
with gr.Blocks() as demo:
    gr.Markdown("""
    # Inpainting Sketch Pad
    by [Tony Assi](https://www.tonyassi.com/)
    """)

    # Top row: drawing surface on the left, version history on the right.
    with gr.Row():
        with gr.Column():
            sketch_pad = gr.ImageMask(label='Inpaint', type='pil')
            generate_button = gr.Button(value="Generate")
        with gr.Column():
            # NOTE: the gallery and the restore button are created here but
            # no event handler is attached to them in this file.
            version_gallery = gr.Gallery(label="Versions")
            restore_button = gr.Button(value="Restore Version")

    # Collapsed by default. These controls are not passed to any handler in
    # this file.
    with gr.Accordion(label="Advanced Settings", open=False):
        neg_prompt = gr.Textbox(value='ugly, deformed, nsfw', label='Negative Prompt')
        strength_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, label="Strength")
        guidance_slider = gr.Slider(minimum=1.0, maximum=15.0, value=7.5, label="Guidance")

    # Output previews; only the first two receive values from `generate`.
    with gr.Row():
        out1 = gr.Image()
        out2 = gr.Image()
        out3 = gr.Image()

    # `generate` returns (editor value, image, mask): the editor value goes
    # back to the sketch pad, the image and mask go to the first two previews.
    generate_button.click(
        fn=generate,
        inputs=sketch_pad,
        outputs=[sketch_pad, out1, out2],
    )

demo.launch()