File size: 5,760 Bytes
4ed7a63 08fc0b2 62bd330 dd2b2f9 fec64c5 ce02dfa 43e72c6 08fc0b2 6f6d7fe 0467837 cfe97ad bc11c81 6f6d7fe 1e716b6 a1deaea 6f6d7fe a1deaea 4ed7a63 1e716b6 d16752a 6f6d7fe d16752a 6f6d7fe d16752a 6f6d7fe d16752a 1e716b6 7139441 91e5c41 fd69a13 7139441 c0c66f4 b2a4599 c0c66f4 dde83e1 9f655e4 ce02dfa 9539987 6f6d7fe d16752a 965c284 6f6d7fe d16752a 6f6d7fe 3e711ca 6f6d7fe d16752a 9539987 6f6d7fe 4ed7a63 6f6d7fe 77f493d 9f655e4 6f6d7fe 93b65f4 3bec293 6f6d7fe a11c80c 6f6d7fe eec3e26 6f6d7fe 10e906d 6f6d7fe ce02dfa 6f6d7fe e6e1231 08fc0b2 0143c38 d040427 0143c38 08fc0b2 f0cce29 429369b e0d2fcf 77f493d e0fc88a 940b814 77f493d 9f655e4 6b87482 ba2bf14 3d2adfb 44e13e8 2395f61 de7dbc9 1e716b6 bc11c81 9f655e4 7139441 c0c66f4 08fc0b2 6f6d7fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import spaces
import gradio as gr
from diffusers import StableDiffusion3InpaintPipeline, AutoencoderKL
import torch
from PIL import Image, ImageOps
import time
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub (the SD3 weights presumably need an
# authorised account). Only log in when a token is configured:
# login(token=None) falls back to an interactive prompt, which cannot be
# answered in a non-interactive Space.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Load the SD3 Medium inpainting pipeline in half precision. It is moved to
# CUDA inside generate(), which runs under @spaces.GPU().
pipeline = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
)
def get_select_index(evt: gr.SelectData):
    """Report which gallery item the user clicked.

    Gradio injects the event payload because of the ``gr.SelectData``
    annotation; the returned position is stored in the hidden ``selected``
    component.
    """
    clicked_position = evt.index
    return clicked_position
def squarify_image(img, color="white"):
    """Pad *img* onto a square RGB canvas whose side equals its longest side.

    The image is anchored at the top-left corner, (0, 0). This is deliberate:
    ``generate`` pastes the mask layer at (0, 0) and crops the result from
    (0, 0), so centring the image here would misalign both. (The previous
    offset expression ``(bg.width - bg.width) / 2`` looked like centring but
    always evaluated to 0.)

    Args:
        img: PIL image to pad.
        color: fill colour for the padding; defaults to ``"white"``.

    Returns:
        A new square ``"RGB"`` image containing *img* at its top-left corner.
    """
    side = max(img.size)
    bg = Image.new(mode="RGB", size=(side, side), color=color)
    bg.paste(img, (0, 0))
    return bg
def divisible_by_8(image, multiple=8):
    """Resize *image* so both sides are multiples of *multiple* (default 8).

    Each side is rounded down to the nearest multiple, but never below one
    multiple. Without that floor, an image narrower than *multiple* would be
    resized to a zero-length side, which Pillow's ``resize`` rejects; such
    tiny images are scaled up to *multiple* instead.

    Args:
        image: object with a ``size`` ``(width, height)`` tuple and a
            ``resize(size)`` method (a PIL image in this app).
        multiple: the value both dimensions must be divisible by.

    Returns:
        The resized image returned by ``image.resize``.
    """
    width, height = image.size
    new_width = max(multiple, (width // multiple) * multiple)
    new_height = max(multiple, (height // multiple) * multiple)
    return image.resize((new_width, new_height))
def restore_version(index, versions):
    """Build an ImageEditor value that restores the gallery item at *index*.

    Args:
        index: position of the selected gallery item, taken from a hidden
            ``gr.Number``. It may arrive as a float (e.g. ``1.0``), or as
            ``None`` when nothing has been selected yet, so it is cast to
            ``int`` and validated before use.
        versions: the version-gallery value; each item is indexed with
            ``[0]`` to obtain its image (e.g. ``(image, caption)`` pairs).

    Returns:
        A dict with ``'background'``, ``'layers'`` and ``'composite'`` keys
        for the ``gr.ImageMask`` component, or a no-op ``gr.update()`` when
        there is no valid selection (missing index, empty gallery, or an index
        left over from a previous, longer gallery).
    """
    print('restore version:', index)
    if index is None or not versions:
        return gr.update()
    position = int(index)
    if not 0 <= position < len(versions):
        return gr.update()
    image = versions[position][0]
    return {'background': image, 'layers': None, 'composite': image}
def clear_all():
    """Reset the interface to its initial, empty state.

    Returns one update per output, in the order wired to the Clear button:
    sketch pad (emptied), prompt (emptied), version gallery (emptied and
    hidden), Restore button (hidden), Clear button (hidden).
    """
    empty_sketch_pad = gr.update(value=None)
    empty_prompt = gr.update(value=None)
    hidden_gallery = gr.update(value=[], visible=False)
    hidden_restore = gr.update(visible=False)
    hidden_clear = gr.update(visible=False)
    return empty_sketch_pad, empty_prompt, hidden_gallery, hidden_restore, hidden_clear
@spaces.GPU()
def generate(image_editor, prompt, neg_prompt, versions, num_inference_steps, guidance_scale):
    """Inpaint the masked area of the sketch-pad image with Stable Diffusion 3.

    The background is shrunk to fit 1024x1024, snapped to multiples of 8 and
    padded to a square anchored at the top-left. The first drawn layer
    becomes the inpainting mask. After generation, the square result is
    cropped back to the original aspect ratio.

    Args:
        image_editor: value of the ``gr.ImageMask`` component, a dict with
            ``'background'`` (PIL image) and ``'layers'`` (list of PIL images).
        prompt: text prompt describing what to paint.
        neg_prompt: negative prompt passed to the pipeline.
        versions: current version-gallery value (a list), or ``None``/empty.
        num_inference_steps: number of denoising steps (cast to ``int``).
        guidance_scale: classifier-free guidance scale.

    Returns:
        Tuple of (ImageEditor dict, updated ``gr.Gallery``, Restore-button
        update, Clear-button update), matching the click handler's outputs.

    Raises:
        gr.Error: when no image was uploaded or no mask layer was drawn.
    """
    start = time.time()

    background = image_editor.get('background') if image_editor else None
    if background is None:
        raise gr.Error("Please upload an image to inpaint.")
    layers = image_editor.get('layers') or []
    if not layers:
        raise gr.Error("Please draw a mask over the area to inpaint.")

    image = background.convert('RGB')
    # Shrink in place to fit within 1024x1024 (keeps aspect ratio), then snap
    # both sides to multiples of 8.
    image.thumbnail((1024, 1024))
    image = divisible_by_8(image)
    original_image_size = image.size
    # Mask layer, brought to the same size as the resized image.
    layer = layers[0].resize(image.size)
    # Pad to a square. squarify_image anchors the image at (0, 0), which is
    # why the mask is pasted at (0, 0) and the result is cropped from (0, 0).
    image = squarify_image(image)
    # Composite the strokes (using their own alpha) onto white, then invert
    # to grayscale. The intent is a white mask on a black background.
    # NOTE(review): this assumes strokes are drawn in a dark colour; confirm
    # the ImageMask brush default.
    mask = Image.new("RGBA", image.size, "WHITE")
    mask.paste(layer, (0, 0), layer)
    mask = ImageOps.invert(mask.convert('L'))

    # Inpaint
    pipeline.to("cuda")
    final_image = pipeline(
        prompt=prompt,
        negative_prompt=neg_prompt,
        image=image,
        mask_image=mask,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
    ).images[0]

    # The pipeline output need not have the input's resolution (previously
    # assumed to be 1024x1024). Scale the original content box by the actual
    # output/input ratio, then crop the square padding away.
    scale_x = final_image.width / image.width
    scale_y = final_image.height / image.height
    crop_width = round(original_image_size[0] * scale_x)
    crop_height = round(original_image_size[1] * scale_y)
    final_image = final_image.crop((0, 0, crop_width, crop_height))

    # gr.ImageEditor expects a dict value.
    final_dict = {'background': final_image, 'layers': None, 'composite': final_image}

    # Start a fresh history (original + result) when the gallery is missing or
    # empty (e.g. after Clear); otherwise append. Copy instead of mutating the
    # incoming list.
    if versions:
        final_gallery = [*versions, final_image]
    else:
        final_gallery = [background, final_image]

    end = time.time()
    print('time:', end - start)
    return final_dict, gr.Gallery(value=final_gallery, visible=True), gr.update(visible=True), gr.update(visible=True)
with gr.Blocks() as demo:
    gr.Markdown("""
# Inpainting SD3 Sketch Pad
Please ❤️ this Space
""")
    with gr.Row():
        with gr.Column():
            sketch_pad = gr.ImageMask(type='pil', label='Inpaint')
            prompt = gr.Textbox(label="Prompt")
            generate_button = gr.Button(value="Inpaint", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt = gr.Textbox(label='Negative Prompt', value='ugly, deformed')
                num_inference_steps = gr.Slider(minimum = 10, maximum = 100, value = 30, step = 1, label = "Number of inference steps", info = "lower=faster, higher=image quality")
                guidance_scale = gr.Slider(minimum = 1, maximum = 13, value = 7, step = 0.1, label = "Classifier-Free Guidance Scale", info = "lower=image quality, higher=follow the prompt")
        with gr.Column():
            version_gallery = gr.Gallery(label="Versions", type="pil", object_fit='contain', visible=False)
            restore_button = gr.Button("Restore Version", visible=False)
            clear_button = gr.Button('Clear', visible=False)
            # Hidden holder for the index of the last-clicked gallery item.
            # precision=0 makes Gradio deliver it as an int, so it can index
            # the gallery list directly in restore_version.
            selected = gr.Number(show_label=False, visible=False, precision=0)

    # Event wiring
    version_gallery.select(get_select_index, None, selected)
    generate_button.click(fn=generate, inputs=[sketch_pad, prompt, neg_prompt, version_gallery, num_inference_steps, guidance_scale], outputs=[sketch_pad, version_gallery, restore_button, clear_button])
    restore_button.click(fn=restore_version, inputs=[selected, version_gallery], outputs=sketch_pad)
    clear_button.click(clear_all, inputs=None, outputs=[sketch_pad, prompt, version_gallery, restore_button, clear_button])

demo.launch()