Kkordik's picture
Update app.py
dc12a0a verified
raw
history blame
4.67 kB
import spaces
import gradio as gr
from diffusers import StableDiffusion3InpaintPipeline, AutoencoderKL
import torch
from PIL import Image, ImageOps
import time
from huggingface_hub import login
import os
# NOTE(review): these pip calls run *after* `import gradio as gr` above, so the
# gradio module already loaded in this process is not replaced by the reinstall;
# pinning the version in the Space's requirements/SDK config is presumably the
# intended mechanism — confirm before relying on this.
os.system("pip uninstall -y gradio")
os.system("pip install gradio==4.29.0")
# Authenticate to the Hugging Face Hub with the token from the HF_TOKEN env var.
# NOTE(review): if HF_TOKEN is unset, login() receives token=None; what happens
# then is up to huggingface_hub (it may try to prompt) — TODO confirm.
login(token=os.getenv("HF_TOKEN"))
# Load the SD3 Medium inpainting pipeline once, at import time, in float16.
# It is moved to CUDA later, inside generate().
pipeline = StableDiffusion3InpaintPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
def get_select_index(evt: gr.SelectData):
    """Report which gallery item the user just clicked.

    Hooked to ``version_gallery.select``; the returned index is written into the
    hidden ``selected`` number component so ``restore_version`` can use it.
    """
    clicked_index = evt.index
    return clicked_index
@spaces.GPU()
def squarify_image(img):
    """Pad ``img`` onto a white square canvas whose side is its longer dimension.

    The image is anchored at the top-left corner (offset ``(0, 0)``), so any
    padding ends up on the right or at the bottom. This placement is deliberate:
    ``generate`` pastes the mask layer at ``(0, 0)`` on the squared canvas and
    crops the model output starting from ``(0, 0)``, so centring the image here
    would misalign both.

    Args:
        img: PIL image to pad.

    Returns:
        A new RGB PIL image of size ``(side, side)`` with ``side = max(w, h)``.
    """
    side = max(img.width, img.height)
    canvas = Image.new(mode="RGB", size=(side, side), color="white")
    # Explicit top-left anchor (see docstring) instead of a centring formula.
    canvas.paste(img, (0, 0))
    return canvas
@spaces.GPU()
def divisible_by_8(image):
    """Resize ``image`` so that both of its dimensions are multiples of 8.

    Each side is rounded down to the nearest multiple of 8 and the image is
    resized (not cropped) to that size. Sides shorter than 8 px are clamped to
    8 instead of collapsing to 0, because Pillow rejects a zero-sized resize
    target.

    Args:
        image: PIL image.

    Returns:
        A new PIL image whose width and height are both multiples of 8 (>= 8).
    """
    width, height = image.size
    new_width = max(8, (width // 8) * 8)
    new_height = max(8, (height // 8) * 8)
    return image.resize((new_width, new_height))
@spaces.GPU()
def restore_version(index, versions):
    """Load a previously generated version back into the sketch pad.

    Args:
        index: Gallery index recorded by ``get_select_index`` and held in a
            ``gr.Number`` component, so it may arrive as a float, or as ``None``
            if the user has not selected a version yet.
        versions: Gallery value; each entry is indexed with ``[0]`` to get the
            image (e.g. an ``(image, caption)`` pair).

    Returns:
        An ImageEditor value dict with the chosen image as both background and
        composite and no layers.

    Raises:
        gr.Error: if no version is selected or the gallery is empty.
    """
    print('restore version:', index)
    if index is None or not versions:
        raise gr.Error("Select a version in the gallery before restoring.")
    # gr.Number can deliver a float (e.g. 1.0); list indexing requires an int.
    chosen = versions[int(index)][0]
    return {'background': chosen, 'layers': None, 'composite': chosen}
def clear_all():
    """Reset the editor to its initial, empty state.

    Returns updates, in order, for: the sketch pad (value cleared), the prompt
    box (value cleared), the version gallery (emptied and hidden), and the
    restore and clear buttons (hidden).
    """
    sketch_pad_reset = gr.update(value=None)
    prompt_reset = gr.update(value=None)
    gallery_reset = gr.update(value=[], visible=False)
    restore_hidden = gr.update(visible=False)
    clear_hidden = gr.update(visible=False)
    return sketch_pad_reset, prompt_reset, gallery_reset, restore_hidden, clear_hidden
@spaces.GPU()
def generate(image_editor, prompt, neg_prompt, versions):
    """Inpaint the masked area of the sketch-pad image with the SD3 pipeline.

    Steps:
      1. Take the editor background as RGB, shrink it to fit within 1024x1024
         and snap both sides to multiples of 8.
      2. Build the inpainting mask from the first editor layer: the layer is
         pasted (using its own alpha) onto a white canvas, converted to
         greyscale and inverted, so painted strokes become light and the rest
         black (assumes a dark brush colour — TODO confirm against the editor's
         brush settings).
      3. Pad the image to a top-left-anchored square, run the pipeline, then
         crop the output back to the area covered by the real image.

    Args:
        image_editor: ImageEditor value dict with ``'background'`` and
            ``'layers'`` PIL images.
        prompt: Text prompt for the pipeline.
        neg_prompt: Negative prompt, forwarded to the pipeline.
        versions: Current gallery value, or ``None``/empty before the first run.

    Returns:
        Tuple of (editor value dict holding the result, updated Gallery,
        restore-button update, clear-button update).

    Raises:
        gr.Error: if there is no background image or no mask layer.
    """
    start = time.time()

    if not image_editor or image_editor['background'] is None:
        raise gr.Error("Upload an image to inpaint first.")
    layers = image_editor["layers"]
    if not layers:
        raise gr.Error("Draw a mask over the area to inpaint first.")
    background = image_editor['background']

    image = background.convert('RGB')
    image.thumbnail((1024, 1024))
    image = divisible_by_8(image)
    content_w, content_h = image.size
    layer = layers[0].resize(image.size)

    # Square canvas; the real image sits at (0, 0), padding to the right/bottom.
    image = squarify_image(image)
    mask = Image.new("RGBA", image.size, "WHITE")
    mask.paste(layer, (0, 0), layer)
    mask = ImageOps.invert(mask.convert('L'))

    pipeline.to("cuda")
    final_image = pipeline(
        prompt=prompt,
        negative_prompt=neg_prompt,
        image=image,
        mask_image=mask,
    ).images[0]

    # Map the content box from the padded input to the output's actual size
    # instead of assuming a fixed 1024x1024 output.
    scale_x = final_image.width / image.width
    scale_y = final_image.height / image.height
    crop_box = (0, 0, round(content_w * scale_x), round(content_h * scale_y))
    final_image = final_image.crop(crop_box)
    final_dict = {'background': final_image, 'layers': None, 'composite': final_image}

    if versions:
        # Copy so the caller's gallery list is not mutated in place.
        final_gallery = list(versions) + [final_image]
    else:
        # First generation (or gallery cleared): seed it with the original.
        final_gallery = [background, final_image]

    end = time.time()
    print('time:', end - start)
    return final_dict, gr.Gallery(value=final_gallery, visible=True), gr.update(visible=True), gr.update(visible=True)
# Gradio UI: sketch pad and prompts on the left, version history on the right.
with gr.Blocks() as demo:
    gr.Markdown("""
# Inpainting SD3 Sketch Pad
by [Tony Assi](https://www.tonyassi.com/)
Please ❤️ this Space. I build custom AI apps for companies. <a href="mailto:tony.assi.media@gmail.com">Email me</a> for business inquiries.
""")
    with gr.Row():
        with gr.Column():
            sketch_pad = gr.ImageMask(type='pil', label='Inpaint')
            prompt = gr.Textbox(label="Prompt")
            generate_button = gr.Button("Generate")
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt = gr.Textbox(label='Negative Prompt', value='ugly, deformed')
        with gr.Column():
            # Hidden until the first generation; generate() makes these visible.
            version_gallery = gr.Gallery(label="Versions", type="pil", object_fit='contain', visible=False)
            restore_button = gr.Button("Restore Version", visible=False)
            clear_button = gr.Button('Clear', visible=False)
    # Hidden holder for the gallery index the user selected. precision=0 makes
    # Gradio return it as an int so it can be used directly as a list index.
    selected = gr.Number(show_label=False, visible=False, precision=0)
    gr.Examples(
        [[{'background':'./tony.jpg', 'layers':['./tony-mask.jpg'], 'composite':'./tony.jpg'}, 'black and white tuxedo, bowtie', 'ugly', None]],
        [sketch_pad, prompt, neg_prompt, version_gallery],
        [sketch_pad, version_gallery, restore_button, clear_button],
        generate,
        cache_examples=True,
    )
    version_gallery.select(get_select_index, None, selected)
    generate_button.click(fn=generate, inputs=[sketch_pad, prompt, neg_prompt, version_gallery], outputs=[sketch_pad, version_gallery, restore_button, clear_button])
    restore_button.click(fn=restore_version, inputs=[selected, version_gallery], outputs=sketch_pad)
    clear_button.click(clear_all, inputs=None, outputs=[sketch_pad, prompt, version_gallery, restore_button, clear_button])
demo.launch()