SDXL-Lightning / app.py
hideosnes's picture
Update app.py
81c93cd verified
raw
history blame
12.1 kB
import cv2
import torch
import random
import tempfile
import numpy as np
from pathlib import Path
from PIL import Image
from diffusers import (
ControlNetModel,
StableDiffusionXLControlNetPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
)
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
from ip_adapter import IPAdapterXL
from safetensors.torch import load_file
# Download everything under ``sdxl_models/`` from the ``h94/IP-Adapter`` Hub
# repository into the working directory. The relative ``sdxl_models/...``
# paths configured further down point into that folder.
snapshot_download(
    repo_id="h94/IP-Adapter",
    allow_patterns="sdxl_models/*",
    local_dir=".",
)

# Module-wide settings.
MAX_SEED = np.iinfo(np.int32).max  # largest value offered for the seed
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if "cuda" in device else torch.float32
# Model locations: Hub repository ids for the base pipeline and the ControlNet,
# and local paths (filled by the snapshot download) for the IP-Adapter files.
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"

# NOTE(review): float16 is hard-coded for both models although a `dtype`
# value is computed next to `device`; it is left untouched here because the
# IP-Adapter wrapper's own precision is not visible from this file — confirm
# before unifying.
controlnet = ControlNetModel.from_pretrained(
    controlnet_path, use_safetensors=False, torch_dtype=torch.float16
).to(device)

# Load the SDXL base pipeline with the Canny ControlNet attached.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
    add_watermarker=False,
).to(device)
pipe.set_progress_bar_config(disable=True)

# Rebuild the scheduler as an Euler scheduler with trailing timestep spacing
# and epsilon prediction, reusing the rest of the existing configuration.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", prediction_type="epsilon"
)

# Replace the UNet weights with the SDXL-Lightning 2-step checkpoint.
# The tensors are loaded onto the selected `device` instead of a hard-coded
# "cuda", so this step no longer fails on a CPU-only host; load_state_dict
# copies the values into the existing parameters either way.
pipe.unet.load_state_dict(
    load_file(
        hf_hub_download(
            "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
        ),
        device=device,
    )
)
# Wrap the pipeline with the IP-Adapter at start-up, using the style-only
# blocks. Note that create_image() builds its own wrapper on every request.
#
# Choices for target_blocks:
#   ["block"]                                                -> original IP-Adapter
#   ["up_blocks.0.attentions.1"]                             -> style blocks only
#   ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] -> style+layout blocks
ip_model = IPAdapterXL(
    pipe, image_encoder_path, ip_ckpt, device,
    target_blocks=["up_blocks.0.attentions.1"],
)
def resize_img(
    input_image,
    max_side=1280,
    min_side=1024,
    size=None,
    pad_to_max_side=False,
    mode=Image.BILINEAR,
    base_pixel_number=64,
):
    """Resize a PIL image, optionally centring it on a white square canvas.

    When ``size`` is given the image is resized to exactly that
    ``(width, height)``. Otherwise the dimensions are first scaled so the
    shorter edge equals ``min_side``, then scaled again so the longer edge
    equals ``max_side``, and finally each edge is rounded down to a multiple
    of ``base_pixel_number``.

    Args:
        input_image: PIL image to resize.
        max_side: Target length of the longer edge; also the edge length of
            the padding canvas.
        min_side: Intermediate target length of the shorter edge.
        size: Optional explicit ``(width, height)`` that bypasses the
            automatic computation.
        pad_to_max_side: If true, paste the result centred onto a white
            ``max_side`` x ``max_side`` 3-channel canvas.
        mode: Resampling filter handed to ``Image.resize``.
        base_pixel_number: Both output edges are multiples of this value
            when ``size`` is not given.

    Returns:
        The resized (and possibly padded) PIL image.
    """
    if size is None:
        width, height = input_image.size
        # Step 1: bring the short edge to min_side (dimensions only).
        grow = min_side / min(height, width)
        width, height = round(grow * width), round(grow * height)
        # Step 2: bring the long edge to max_side and resample once.
        fit = max_side / max(height, width)
        scaled_w, scaled_h = round(fit * width), round(fit * height)
        input_image = input_image.resize([scaled_w, scaled_h], mode)
        # Step 3: snap each edge down to a multiple of base_pixel_number.
        target_w = (scaled_w // base_pixel_number) * base_pixel_number
        target_h = (scaled_h // base_pixel_number) * base_pixel_number
    else:
        target_w, target_h = size

    input_image = input_image.resize([target_w, target_h], mode)
    if not pad_to_max_side:
        return input_image

    # White square canvas with the resized image pasted in the middle.
    canvas = np.full([max_side, max_side, 3], 255, dtype=np.uint8)
    left = (max_side - target_w) // 2
    top = (max_side - target_h) // 2
    canvas[top : top + target_h, left : left + target_w] = np.array(input_image)
    return Image.fromarray(canvas)
# Rows shown in the examples gallery. Column order:
#   style image path, source image path (or None), prompt,
#   scale, ControlNet scale
examples = [
    # Style transfer only: no source image, ControlNet scale 0.
    ["./asset/0.jpg", None, "3D model, cute monster, high quality", 1.0, 0.0],
    # Style transfer guided by a source image.
    [
        "./asset/2.jpg",
        "./asset/house.jpg",
        "3d model, house, kawai, cute, sci-fi, solarpunk, high quality",
        1.0,
        0.6,
    ],
]
def run_for_examples(style_image, source_image, prompt, scale, control_scale):
    """Run ``create_image`` for one examples row with fixed generation settings.

    Only the five per-example values are taken from the caller; the negative
    prompt, guidance scale, step count, seed, style mode and negative-content
    options are pinned to the constants below.

    Returns:
        Whatever ``create_image`` returns for these arguments.
    """
    fixed_settings = {
        "n_prompt": "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
        "guidance_scale": 0.0,
        "num_inference_steps": 2,
        "seed": 42,
        "target": "Load only style blocks",
        "neg_content_prompt": "",
        "neg_content_scale": 0,
    }
    return create_image(
        image_pil=style_image,
        input_image=source_image,
        prompt=prompt,
        scale=scale,
        control_scale=control_scale,
        **fixed_settings,
    )
# ``target_blocks`` value handed to IPAdapterXL for each "Style mode" label.
# NOTE(review): "blocks" is kept exactly as the original handler had it;
# confirm that this pattern selects the intended layers.
_TARGET_BLOCKS = {
    "Load original IP-Adapter": ["blocks"],
    "Load only style blocks": ["up_blocks.0.attentions.1"],
    "Load style+layout block": [
        "up_blocks.0.attentions.1",
        "down_blocks.2.attentions.1",
    ],
}


@spaces.GPU
def create_image(
    image_pil,
    input_image,
    prompt,
    n_prompt,
    scale,
    control_scale,
    guidance_scale,
    num_inference_steps,
    seed,
    target="Load only style blocks",
    neg_content_prompt=None,
    neg_content_scale=0,
):
    """Generate one image from a style reference and return its JPEG path.

    A new ``IPAdapterXL`` wrapper around the shared ``pipe`` is built on every
    call with the ``target_blocks`` selected by ``target``. If a source image
    is supplied, its Canny edge map becomes the ControlNet conditioning
    image; otherwise a blank white 1024x1024 image is used and the ControlNet
    scale is forced to 0.

    Args:
        image_pil: Style reference image, forwarded as ``pil_image``.
        input_image: Optional source PIL image used for edge detection.
        prompt: Text prompt.
        n_prompt: Negative text prompt.
        scale: Forwarded as ``scale`` to ``ip_model.generate``.
        control_scale: ControlNet conditioning scale.
        guidance_scale: Forwarded as ``guidance_scale``.
        num_inference_steps: Number of inference steps (cast to ``int``).
        seed: Seed for generation; ``-1`` draws a random seed in
            ``[0, MAX_SEED]``.
        target: One of the keys of ``_TARGET_BLOCKS``.
        neg_content_prompt: Optional negative content prompt. Forwarded only
            when it is non-empty and ``neg_content_scale`` is non-zero.
        neg_content_scale: Scale that accompanies ``neg_content_prompt``.

    Returns:
        ``pathlib.Path`` of a temporary JPEG file holding the first generated
        image. The file is created with ``delete=False`` and is not removed
        here.

    Raises:
        ValueError: If ``target`` is not a known style mode.
    """
    try:
        target_blocks = _TARGET_BLOCKS[target]
    except KeyError:
        # Previously an unknown label left `ip_model` unbound and surfaced as
        # an UnboundLocalError much further down.
        raise ValueError(f"Unknown style mode: {target!r}") from None

    # Slider values may arrive as floats; the seed must be an integer.
    seed = random.randint(0, MAX_SEED) if seed == -1 else int(seed)

    ip_model = IPAdapterXL(
        pipe, image_encoder_path, ip_ckpt, device, target_blocks=target_blocks
    )

    if input_image is not None:
        input_image = resize_img(input_image, max_side=1024)
        cv_input_image = pil_to_cv2(input_image)
        detected_map = cv2.Canny(cv_input_image, 50, 200)
        # Canny returns a single-channel edge map, so it has to be expanded
        # with GRAY2RGB; BGR2RGB only accepts 3- or 4-channel input.
        canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_GRAY2RGB))
    else:
        # No source image: feed a blank control image and disable ControlNet.
        canny_map = Image.new("RGB", (1024, 1024), color=(255, 255, 255))
        control_scale = 0

    control_scale = float(control_scale)
    if control_scale == 0:
        canny_map = canny_map.resize((1024, 1024))

    generate_kwargs = {
        "pil_image": image_pil,
        "prompt": prompt,
        "negative_prompt": n_prompt,
        "scale": scale,
        "guidance_scale": guidance_scale,
        "num_samples": 1,
        "num_inference_steps": int(num_inference_steps),
        "seed": seed,
        "image": canny_map,
        "controlnet_conditioning_scale": control_scale,
    }
    # Truthiness check also covers the declared default of None, which the
    # former len() call could not handle.
    if neg_content_prompt and neg_content_scale != 0:
        generate_kwargs["neg_content_prompt"] = neg_content_prompt
        generate_kwargs["neg_content_scale"] = neg_content_scale

    images = ip_model.generate(**generate_kwargs)

    # Persist the result as a JPEG; the file is closed before its path is
    # returned.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmpfile:
        images[0].save(tmpfile, "JPEG", quality=80, optimize=True, progressive=True)
    return Path(tmpfile.name)
def pil_to_cv2(image_pil):
    """Return ``image_pil`` as a NumPy array with RGB channels swapped to BGR."""
    return cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
# Static HTML/Markdown snippets rendered above and below the interface.
# Each value starts and ends with a newline, exactly like the former
# triple-quoted literals.

# Page heading.
title = (
    "\n"
    '<h1 align="center">I2I mit SDXL-Lightning & IP-Adapter</h1>\n'
)

# Short note to testers shown under the heading.
description = (
    "\n"
    "<b>ARM <3 GoldExtra Testversion<br>\n"
    "<b>Wir schauen uns gut funktionierende Prompts. "
    "Bitte diese notieren und an Hidéo weiterleiten!</b><br>\n"
)

# Footer with a contact link.
article = (
    "\n"
    "<br>\n"
    'Bei Fragen: <a href="mailto:hideo@artificialmuseum.com">Mail an Hidéo</a>\n'
)
# Build the Gradio interface and start the app.
# NOTE(review): the original text carried no indentation, so the nesting of
# the layout containers below is a reconstruction — confirm it against the
# intended page layout. Component arguments and strings are unchanged.
with gr.Blocks() as block:
    # Page header.
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.Row():
            # Input side: images, prompt, advanced options, trigger button.
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                image_pil = gr.Image(label="Style Image", type="pil")
                            with gr.Column():
                                # Display-only component; no event defined in
                                # this section writes to it.
                                processed_image = gr.Image(
                                    label="Preprocess uWu", interactive=False
                                )
                    with gr.Column():
                        prompt = gr.Textbox(
                            label="Prompt",
                            value="3d render, 3d model, clean 3d style, cute space monster, white backround, cinematic lighting,",
                        )
                        scale = gr.Slider(
                            minimum=0, maximum=2.0, step=0.01, value=1.0, label="Scale"
                        )

                with gr.Accordion(open=False, label="Details (optional)"):
                    # The labels must match the style modes understood by
                    # create_image.
                    target = gr.Radio(
                        [
                            "Load only style blocks",
                            "Load style+layout block",
                            "Load original IP-Adapter",
                        ],
                        value="Load only style blocks",
                        label="Style mode (optional, sb works best!)",
                    )
                    with gr.Column():
                        src_image_pil = gr.Image(
                            label="Source Image (optional)", type="pil"
                        )
                        control_scale = gr.Slider(
                            minimum=0,
                            maximum=1.0,
                            step=0.01,
                            value=0.5,
                            label="ControlNet Scale (test this!)",
                        )
                        n_prompt = gr.Textbox(
                            label="Negative Prompt // n_prompt",
                            value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                        )
                        neg_content_prompt = gr.Textbox(
                            label="Negative Content Prompt (Ignore this!)", value=""
                        )
                        neg_content_scale = gr.Slider(
                            minimum=0,
                            maximum=1.0,
                            step=0.01,
                            value=0.5,
                            label="NCS (Ignore this!) // neg_content_scale",
                        )
                        guidance_scale = gr.Slider(
                            minimum=0,
                            maximum=10.0,
                            step=0.01,
                            value=0.0,
                            label="Guidance Scale (test this!)",
                        )
                        num_inference_steps = gr.Slider(
                            minimum=2,
                            maximum=50.0,
                            step=1.0,
                            value=2,
                            label="Inference Steps (optional but test with 2+)",
                        )
                        seed = gr.Slider(
                            minimum=-1,
                            maximum=MAX_SEED,
                            value=-1,
                            step=1,
                            label="Seed Value (Seed-Proof) // -1 == random",
                        )

                generate_button = gr.Button("Simsalabim")

            # Output side.
            with gr.Column():
                generated_image = gr.Image(label="Magix uWu")

        # Passed positionally, so the order has to follow the parameter
        # order of create_image.
        inputs = [
            image_pil,
            src_image_pil,
            prompt,
            n_prompt,
            scale,
            control_scale,
            guidance_scale,
            num_inference_steps,
            seed,
            target,
            neg_content_prompt,
            neg_content_scale,
        ]
        outputs = [generated_image]

        # Generation runs only on an explicit button click.
        gr.on(
            triggers=[generate_button.click],
            fn=create_image,
            inputs=inputs,
            outputs=outputs,
            show_progress="minimal",
            show_api=False,
            trigger_mode="always_last",
        )

        # Example rows supply only the five values that run_for_examples
        # accepts; their results are cached.
        gr.Examples(
            examples=examples,
            inputs=[image_pil, src_image_pil, prompt, scale, control_scale],
            fn=run_for_examples,
            outputs=[generated_image],
            cache_examples=True,
        )

    # Page footer.
    gr.Markdown(article)

block.queue(api_open=False)
block.launch(show_api=False)