Kkordik's picture
Update app.py
a998593 verified
raw
history blame
5.39 kB
import spaces
import gradio as gr
from diffusers import StableDiffusion3InpaintPipeline, AutoencoderKL
import torch
from PIL import Image, ImageOps
import time
from huggingface_hub import login
import os
# Replace whatever gradio build is installed with a pinned release at startup.
# NOTE(review): gradio has already been imported above, so presumably this
# reinstall only affects processes started afterwards — confirm it is still
# needed and consider pinning the version in the Space requirements instead.
os.system("pip uninstall -y gradio")
os.system("pip install gradio==4.29.0")
# Authenticate with the Hugging Face Hub using the HF_TOKEN environment variable.
# NOTE(review): os.getenv returns None when HF_TOKEN is unset — confirm how
# login() behaves in that case on a headless deployment.
login(token=os.getenv("HF_TOKEN"))
# Earlier experiment with a custom fp16 VAE, kept for reference (not executed):
# vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
# pipeline = StableDiffusion3InpaintPipeline(vae=vae, torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to("cuda")
# Load the Stable Diffusion 3 medium inpainting pipeline with float16 weights.
# It is created here without a device move; generate() calls pipeline.to("cuda").
pipeline = StableDiffusion3InpaintPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
def get_select_index(evt: gr.SelectData):
    """Return the index carried by a selection event.

    Args:
        evt: Selection event object; only its ``index`` attribute is read.

    Returns:
        The value of ``evt.index``, unchanged.
    """
    selected_index = evt.index
    return selected_index
@spaces.GPU()
def squarify_image(img):
    """Pad *img* onto a white square canvas, anchored at the top-left corner.

    The canvas side equals the longer side of *img*; the extra area (to the
    right of or below the image) is filled with white.

    Args:
        img: Image exposing ``width`` and ``height`` that can be pasted onto
            a PIL image.

    Returns:
        A new RGB image of size ``(side, side)`` with *img* pasted at (0, 0).
    """
    side = max(img.width, img.height)
    canvas = Image.new(mode="RGB", size=(side, side), color="white")
    # The previous offset expression, int((bg.width - bg.width) / 2), always
    # evaluated to 0. The top-left anchoring is kept on purpose: generate()
    # crops the inpainted result starting from (0, 0), so centring the image
    # here would shift the crop window off the picture.
    canvas.paste(img, (0, 0))
    return canvas
@spaces.GPU()
def divisible_by_8(image):
    """Resize *image* so that both dimensions are multiples of 8.

    Each dimension is rounded down to the nearest multiple of 8. A dimension
    shorter than 8 pixels would round down to 0, which is not a valid resize
    target, so it is clamped up to 8 instead.

    Args:
        image: Image exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height))`` method.

    Returns:
        The resized image returned by ``image.resize``.
    """
    width, height = image.size
    # Floor to a multiple of 8, but never below 8 so the target stays non-empty.
    new_width = max(8, (width // 8) * 8)
    new_height = max(8, (height // 8) * 8)
    return image.resize((new_width, new_height))
@spaces.GPU()
def restore_version(index, versions):
    """Build an image-editor value from a previously generated version.

    Args:
        index: Position of the version to restore. It may arrive as ``None``
            (nothing selected yet) or as a float, so it is validated and
            converted to ``int`` before use.
        versions: Sequence of gallery entries. Each entry is indexed with
            ``[0]`` to obtain the image — presumably entries are
            ``(image, caption)`` pairs; verify against the gallery component.

    Returns:
        A dict with ``'background'``, ``'layers'`` and ``'composite'`` keys for
        the editor, or ``gr.update()`` (no change) when there is nothing valid
        to restore.
    """
    print('restore version:', index)
    # Nothing selected or no history yet: leave the editor untouched rather
    # than failing on versions[None] / indexing an empty list.
    if index is None or not versions:
        return gr.update()
    position = int(index)
    if not 0 <= position < len(versions):
        return gr.update()
    restored = versions[position][0]
    return {'background': restored, 'layers': None, 'composite': restored}
def clear_all():
    """Produce the updates that reset the interface to its initial state.

    Returns:
        A 5-tuple of ``gr.update`` objects, in order: clear a value, clear a
        second value, empty and hide a component, hide a component, hide a
        component. The click handler maps them onto the sketch pad, prompt,
        version gallery, restore button and clear button respectively.
    """
    reset_editor = gr.update(value=None)
    reset_prompt = gr.update(value=None)
    empty_hidden_gallery = gr.update(value=[], visible=False)
    hide_restore = gr.update(visible=False)
    hide_clear = gr.update(visible=False)
    return (
        reset_editor,
        reset_prompt,
        empty_hidden_gallery,
        hide_restore,
        hide_clear,
    )
@spaces.GPU()
def generate(image_editor, prompt, neg_prompt, versions):
    """Inpaint the masked area of the edited image and record the result.

    Args:
        image_editor: Editor value; ``'background'`` is the source image and
            ``'layers'[0]`` is the drawn mask layer.
        prompt: Text prompt passed to the inpainting pipeline.
        neg_prompt: Negative prompt passed to the pipeline (an empty value is
            sent as ``None``).
        versions: Existing gallery entries, or ``None``/empty on the first run.

    Returns:
        A 4-tuple: the new editor value (dict with ``'background'``,
        ``'layers'``, ``'composite'``), a visible ``gr.Gallery`` holding all
        versions, and two ``gr.update(visible=True)`` objects.

    Raises:
        gr.Error: If no background image or no mask layer is present.
    """
    start = time.time()

    background = image_editor['background'] if image_editor else None
    if background is None:
        raise gr.Error("Please upload an image before generating.")
    layers = image_editor.get("layers") or []
    if not layers:
        raise gr.Error("Please draw a mask over the area to inpaint.")

    image = background.convert('RGB')
    # Shrink in place (on the converted copy) so neither side exceeds 1024
    image.thumbnail((1024, 1024))
    image = divisible_by_8(image)
    original_image_size = image.size

    # Mask layer, scaled to match the resized image
    layer = layers[0].resize(image.size)

    # Make image a square
    image = squarify_image(image)

    # Make sure mask is white with a black background: paste the layer onto
    # white using itself as the transparency mask, then invert the grayscale.
    mask = Image.new("RGBA", image.size, "WHITE")
    mask.paste(layer, (0, 0), layer)
    mask = ImageOps.invert(mask.convert('L'))

    # Inpaint
    pipeline.to("cuda")
    final_image = pipeline(
        prompt=prompt,
        # Previously the negative prompt was collected but never forwarded.
        negative_prompt=neg_prompt or None,
        image=image,
        mask_image=mask,
    ).images[0]

    # Scale the pre-squaring size so its longest side is 1024, then crop the
    # result back to the original aspect ratio from the top-left corner.
    # NOTE(review): assumes the pipeline output is 1024 px on its longest
    # side — confirm against the pipeline's default height/width.
    scale = 1024 / max(original_image_size)
    crop_width = int(round(original_image_size[0] * scale))
    crop_height = int(round(original_image_size[1] * scale))
    final_image = final_image.crop((0, 0, crop_width, crop_height))

    # gradio.ImageEditor requires a dictionary
    final_dict = {'background': final_image, 'layers': None, 'composite': final_image}

    # Add generated image to version gallery; build a new list instead of
    # mutating the incoming gallery value.
    if not versions:
        final_gallery = [background, final_image]
    else:
        final_gallery = [*versions, final_image]

    end = time.time()
    print('time:', end - start)

    return final_dict, gr.Gallery(value=final_gallery, visible=True), gr.update(visible=True), gr.update(visible=True)
# Build the two-column interface: editor, prompt and settings on the left,
# version history with its restore/clear controls on the right.
with gr.Blocks() as demo:
    gr.Markdown("""
# Inpainting SDXL Sketch Pad
by [Tony Assi](https://www.tonyassi.com/)
Please ❤️ this Space. I build custom AI apps for companies. <a href="mailto: tony.assi.media@gmail.com">Email me</a> for business inquiries.
""")

    with gr.Row():
        # Left column: inputs
        with gr.Column():
            sketch_pad = gr.ImageMask(type='pil', label='Inpaint')
            prompt = gr.Textbox(label="Prompt")
            generate_button = gr.Button("Generate")
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt = gr.Textbox(label='Negative Prompt', value='ugly, deformed')

        # Right column: history; everything starts hidden until generate() runs
        with gr.Column():
            version_gallery = gr.Gallery(label="Versions", type="pil", object_fit='contain', visible=False)
            restore_button = gr.Button("Restore Version", visible=False)
            clear_button = gr.Button('Clear', visible=False)
            # Hidden holder for the index of the gallery item last selected
            selected = gr.Number(show_label=False, visible=False)

    # Components shared by the example runner and the Generate button
    generate_inputs = [sketch_pad, prompt, neg_prompt, version_gallery]
    generate_outputs = [sketch_pad, version_gallery, restore_button, clear_button]

    example_editor_value = {
        'background': './tony.jpg',
        'layers': ['./tony-mask.jpg'],
        'composite': './tony.jpg',
    }
    gr.Examples(
        examples=[[example_editor_value, 'black and white tuxedo, bowtie', 'ugly', None]],
        inputs=list(generate_inputs),
        outputs=list(generate_outputs),
        fn=generate,
        cache_examples=True,
    )

    # Event wiring
    version_gallery.select(fn=get_select_index, inputs=None, outputs=selected)
    generate_button.click(fn=generate, inputs=generate_inputs, outputs=generate_outputs)
    restore_button.click(fn=restore_version, inputs=[selected, version_gallery], outputs=sketch_pad)
    clear_button.click(
        fn=clear_all,
        inputs=None,
        outputs=[sketch_pad, prompt, version_gallery, restore_button, clear_button],
    )

demo.launch()