# SDXL-Lightning / app.py
import cv2
import torch
import random
import tempfile
import numpy as np
from pathlib import Path
from PIL import Image
from diffusers import (
ControlNetModel,
StableDiffusionXLControlNetPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
)
import spaces
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
from ip_adapter import IPAdapterXL
from safetensors.torch import load_file
from rembg import remove
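# Download the IP-Adapter SDXL weights (image encoder + adapter checkpoint) into ./sdxl_models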
snapshot_download(
repo_id="h94/IP-Adapter", allow_patterns="sdxl_models/*", local_dir="."
)
# global variables
MAX_SEED = np.iinfo(np.int32).max
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if "cuda" in device else torch.float32
# initialization
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
controlnet = ControlNetModel.from_pretrained(
controlnet_path, use_safetensors=False, torch_dtype=dtype
).to(device)
# load SDXL-Lightning
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path,
controlnet=controlnet,
torch_dtype=dtype,
variant="fp16",
add_watermarker=False,
).to(device)
pipe.set_progress_bar_config(disable=True)
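# SDXL-Lightning is distilled for Euler sampling with trailing timestep spacing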
pipe.scheduler = EulerDiscreteScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing", prediction_type="epsilon"
)
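# Replace the base UNet weights with the distilled 2-step SDXL-Lightning checkpoint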
pipe.unet.load_state_dict(
load_file(
hf_hub_download(
"ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
),
device=device,
)
)
# load ip-adapter
# target_blocks=["block"] for original IP-Adapter
# target_blocks=["up_blocks.0.attentions.1"] for style blocks only
# target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
ip_model = IPAdapterXL(
pipe,
image_encoder_path,
ip_ckpt,
device,
target_blocks=["up_blocks.0.attentions.1"],
)
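# Default adapter: style-blocks-only injection (InstantStyle); create_image() rebuilds the
# adapter according to the mode selected in the UI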
def segment(input_image):
# Convert PIL image to NumPy array
image_np = np.array(input_image)
# Remove background
image_np = remove(image_np)
# Convert back to PIL image
input_image = Image.fromarray(image_np)
return input_image
def resize_img(
input_image,
# max_side=1280,
# min_side=1024,
max_side=1024,
min_side=512,
size=None,
pad_to_max_side=False,
mode=Image.BILINEAR,
base_pixel_number=64,
):
w, h = input_image.size
if size is not None:
w_resize_new, h_resize_new = size
else:
ratio = min_side / min(h, w)
w, h = round(ratio * w), round(ratio * h)
ratio = max_side / max(h, w)
input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
input_image = input_image.resize([w_resize_new, h_resize_new], mode)
if pad_to_max_side:
res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
offset_x = (max_side - w_resize_new) // 2
offset_y = (max_side - h_resize_new) // 2
res[offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new] = (
np.array(input_image)
)
input_image = Image.fromarray(res)
return input_image
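# Example rows: style image, optional source image, prompt, IP-Adapter scale, ControlNet scale, guidance scale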
examples = [
[
"./asset/0.jpg",
None,
"3d render, 3d model, clean 3d style, cute space monster on mars, isolated clean white background, cinematic lighting",
1.0,
0.0,
0.0,
],
[
"./asset/zeichnung1.jpg",
"./asset/zeichnung1mask.png",
"3d render, 3d model, clean 3d style, space ship, isolated clean white background, cinematic lighting",
0.95,
0.5,
0.5,
],
[
"./asset/zeichnung2.jpg",
"./asset/zeichnung2mask.png",
"3d render, 3d model, clean 3d style, space station on mars, isolated clean white background, cinematic lighting",
1.0,
0.5,
0.3,
],
[
"./asset/zeichnung3.jpg",
"./asset/zeichnung3mask.png",
"3d render, 3d model, clean 3d style, cute space astronaut on mars, isolated clean white background, cinematic lighting",
1.0,
0.45,
1.5,
],
]
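# Wrapper used by gr.Examples: the guidance scale from the example row is ignored here and
# generation is pinned to Lightning-friendly defaults (guidance_scale=0.0, 2 steps, fixed seed)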
def run_for_examples(style_image, source_image, prompt, scale, control_scale, guidance_scale):
# added
# if source_image is not None:
# source_image = segment(source_image)
#/added
return create_image(
image_pil=style_image,
input_image=source_image,
prompt=prompt,
n_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
scale=scale,
control_scale=control_scale,
guidance_scale=0.0,
num_inference_steps=2,
seed=42,
target="Load only style blocks",
neg_content_prompt="",
neg_content_scale=0,
)
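# Main generation entry point (runs on GPU): rebuilds the IP-Adapter for the selected target blocks,
# derives a Canny ControlNet map from the optional source image and returns the result as a temporary JPEG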
@spaces.GPU
def create_image(
image_pil,
input_image,
# added
# src_image_pil,
#/added
prompt,
n_prompt,
scale,
control_scale,
guidance_scale,
num_inference_steps,
seed,
target="Load only style blocks",
neg_content_prompt=None,
neg_content_scale=0,
):
seed = random.randint(0, MAX_SEED) if seed == -1 else seed
if target == "Load original IP-Adapter":
# target_blocks=["blocks"] for original IP-Adapter
ip_model = IPAdapterXL(
pipe, image_encoder_path, ip_ckpt, device, target_blocks=["blocks"]
)
elif target == "Load only style blocks":
# target_blocks=["up_blocks.0.attentions.1"] for style blocks only
ip_model = IPAdapterXL(
pipe,
image_encoder_path,
ip_ckpt,
device,
target_blocks=["up_blocks.0.attentions.1"],
)
elif target == "Load style+layout block":
# target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"] # for style+layout blocks
ip_model = IPAdapterXL(
pipe,
image_encoder_path,
ip_ckpt,
device,
target_blocks=["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"],
)
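# ControlNet conditioning: a Canny edge map of the resized source image, or a blank
# white canvas with the ControlNet scale zeroed out when no source image is provided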
if input_image is not None:
input_image = resize_img(input_image, max_side=1024)
cv_input_image = pil_to_cv2(input_image)
detected_map = cv2.Canny(cv_input_image, 50, 200)
canny_map = Image.fromarray(cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB))
else:
canny_map = Image.new("RGB", (1024, 1024), color=(255, 255, 255))
control_scale = 0
if float(control_scale) == 0:
canny_map = canny_map.resize((1024, 1024))
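# Use InstantStyle's content-negation path only when a negative content prompt and scale are given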
if neg_content_prompt and neg_content_scale != 0:
images = ip_model.generate(
pil_image=image_pil,
prompt=prompt,
negative_prompt=n_prompt,
scale=scale,
guidance_scale=guidance_scale,
num_samples=1,
num_inference_steps=num_inference_steps,
seed=seed,
image=canny_map,
controlnet_conditioning_scale=float(control_scale),
neg_content_prompt=neg_content_prompt,
neg_content_scale=neg_content_scale,
)
else:
images = ip_model.generate(
pil_image=image_pil,
prompt=prompt,
negative_prompt=n_prompt,
scale=scale,
guidance_scale=guidance_scale,
num_samples=1,
num_inference_steps=num_inference_steps,
seed=seed,
image=canny_map,
controlnet_conditioning_scale=float(control_scale),
)
image = images[0]
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmpfile:
image.save(tmpfile, "JPEG", quality=80, optimize=True, progressive=True)
return Path(tmpfile.name)
def pil_to_cv2(image_pil):
image_np = np.array(image_pil)
image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
return image_cv2
# Description
title = r"""
<h1 align="center">Pipeline 1: 2.5D( image to image )</h1>
"""
description = r"""
<b>ARM <3 GoldExtra AI demo #1</b>
<p>In the final version this part will run on a dedicated GPU! Please note: with this demo on a "Zero" GPU there can be short waiting times
and performance limits, because Zeros are cheap for testing but are shared by many people at the same time. If that happens, please
wait a moment and it will be back.</p><br>
<p>Normally a <b>Scale</b> of '1' should give the best results. It is a normalized value we are aiming for. At '1.x' the AI stays closer
to the image, and at '0.x' the AI starts to invent additional image content.<br>
You can also open the details and turn the finer knobs. The settings that matter for this project are the <b>ControlNet Scale</b>, which changes
how much weight the AI gives to the input image, and the <b>Guidance Scale</b>, which controls the importance of the text prompt. You can ignore the remaining settings.
The source image acts as a kind of drawing template for the AI, i.e. it will try to follow the content of the template as closely as possible during image synthesis.
In the final version a template of the drawing will be generated here automatically, so that the result of the 2.5D synthesis stays as close as possible (colour, shape, composition)
to the child's drawing while at the same time producing an image that gives comparatively good results in the 3D synthesis. The painterly, slightly washed-out style
of the 2.5D synthesis is intentional and is needed for the 3D synthesis.</p>
"""
article = r"""
<br>
For technical questions please send a <a href="mailto:hideo@artificialmuseum.com">mail to Hidéo</a>, or for general questions send a
<a href="mailto:team@artificialmuseum.com">mail to the ARM team</a>.
"""
block = gr.Blocks()
with block:
# description
gr.Markdown(title)
gr.Markdown(description)
with gr.Tabs():
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
with gr.Row():
# with gr.Column():
image_pil = gr.Image(label="Style Image", type="pil")
with gr.Column():
prompt = gr.Textbox(
label="Text-Prompts",
value="<How (Guidance)>, <What (Objekt)>, <Where (Location)>, <With (Qualität)>",
)
scale = gr.Slider(
minimum=0, maximum=2.0, step=0.01, value=1.0, label="Scale"
)
with gr.Accordion(open=False, label="Under the hood (optional)"):
target = gr.Radio(
[
"Load only style blocks",
"Load style+layout block",
"Load original IP-Adapter",
],
value="Load only style blocks",
label="Adapter Modus",
)
with gr.Column():
src_image_pil = gr.Image(
label="Source Image (optional/wird generiert)", type="pil"
)
control_scale = gr.Slider(
minimum=0,
maximum=1.0,
step=0.01,
value=0.5,
label="ControlNet Scale (Relevanz: Input-Bild)",
)
n_prompt = gr.Textbox(
label="Negative Prompt (Ignore this!)",
value="text, watermark, lowres, low quality, worst quality, deformed, detached, broken, glitch, low contrast, noisy, saturation, blurry, blur",
)
neg_content_prompt = gr.Textbox(
label="Negative Content Prompt (Ignore this!)", value=""
)
neg_content_scale = gr.Slider(
minimum=0,
maximum=1.0,
step=0.01,
value=0.5,
label="NCS (Ignore this!) // neg_content_scale",
)
guidance_scale = gr.Slider(
minimum=0,
maximum=10.0,
step=0.01,
value=0.0,
label="Guidance Scale (Relevanz: Text-Prompt)",
)
num_inference_steps = gr.Slider(
minimum=2,
maximum=50.0,
step=1.0,
value=2,
label="Inference Steps (Stärke der Bildsynthese)",
)
seed = gr.Slider(
minimum=-1,
maximum=MAX_SEED,
value=-1,
step=1,
label="Seed Value (Seed-Proof) // -1 == random",
)
generate_button = gr.Button("Simsalabim")
with gr.Column():
generated_image = gr.Image(label="Magix uWu")
inputs = [
image_pil,
src_image_pil,
prompt,
n_prompt,
scale,
control_scale,
guidance_scale,
num_inference_steps,
seed,
target,
neg_content_prompt,
neg_content_scale,
]
outputs = [generated_image]
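# Only the button click triggers generation; the commented-out triggers would re-run on input changes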
gr.on(
triggers=[
# prompt.input,
generate_button.click,
# guidance_scale.input,
# scale.input,
# control_scale.input,
# seed.input,
],
fn=create_image,
inputs=inputs,
outputs=outputs,
show_progress="minimal",
show_api=False,
trigger_mode="always_last",
)
gr.Examples(
examples=examples,
inputs=[image_pil, src_image_pil, prompt, scale, control_scale, guidance_scale],
fn=run_for_examples,
outputs=[generated_image],
cache_examples=True,
)
gr.Markdown(article)
block.queue(api_open=False)
block.launch(show_api=False)