ctrl-x / app_ctrlx.py
multimodalart's picture
Update app_ctrlx.py
c343272 verified
raw
history blame
17.6 kB
from argparse import ArgumentParser
from diffusers import DDIMScheduler, StableDiffusionXLImg2ImgPipeline
import gradio as gr
import torch
import yaml
from ctrl_x.pipelines.pipeline_sdxl import CtrlXStableDiffusionXLPipeline
from ctrl_x.utils import *
from ctrl_x.utils.sdxl import *
import spaces
parser = ArgumentParser()
parser.add_argument("-m", "--model", type=str, default=None) # Optionally, load model checkpoint from single file
args = parser.parse_args()
torch.backends.cudnn.enabled = False # Sometimes necessary to suppress CUDNN_STATUS_NOT_SUPPORTED
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
refiner_id_or_path = "stabilityai/stable-diffusion-xl-refiner-1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"
#variant = "fp16" if device == "cuda" else "fp32"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler") # TODO: Support other schedulers
if args.model is None:
pipe = CtrlXStableDiffusionXLPipeline.from_pretrained(
model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype, use_safetensors=True
)
else:
print(f"Using weights {args.model} for SDXL base model.")
pipe = CtrlXStableDiffusionXLPipeline.from_single_file(args.model, scheduler=scheduler, torch_dtype=torch_dtype)
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
refiner_id_or_path, scheduler=scheduler, text_encoder_2=pipe.text_encoder_2, vae=pipe.vae,
torch_dtype=torch_dtype, use_safetensors=True,
)
if torch.cuda.is_available():
pipe = pipe.to("cuda")
refiner = refiner.to("cuda")
def get_control_config(structure_schedule, appearance_schedule):
s = structure_schedule
a = appearance_schedule
control_config =\
f"""control_schedule:
# structure_conv structure_attn appearance_attn conv/attn
encoder: # (num layers)
0: [[ ], [ ], [ ]] # 2/0
1: [[ ], [ ], [{a}, {a} ]] # 2/2
2: [[ ], [ ], [{a}, {a} ]] # 2/2
middle: [[ ], [ ], [ ]] # 2/1
decoder:
0: [[{s} ], [{s}, {s}, {s}], [0.0, {a}, {a}]] # 3/3
1: [[ ], [ ], [{a}, {a} ]] # 3/3
2: [[ ], [ ], [ ]] # 3/0
control_target:
- [output_tensor] # structure_conv choices: {{hidden_states, output_tensor}}
- [query, key] # structure_attn choices: {{query, key, value}}
- [before] # appearance_attn choices: {{before, value, after}}
self_recurrence_schedule:
- [0.1, 0.5, 2] # format: [start, end, num_recurrence]"""
return control_config
css = """
.config textarea {font-family: monospace; font-size: 80%; white-space: pre}
.mono {font-family: monospace}
"""
title = """
<div style="display: flex; align-items: center; justify-content: center;margin-bottom: -15px">
<h1 style="margin-left: 12px;text-align: center;display: inline-block">
Ctrl-X: Controlling Structure and Appearance for Text-To-Image Generation Without Guidance
</h1>
<h3 style="display: inline-block; margin-left: 10px; margin-top: 7.5px; font-weight: 500">
SDXL v1.0
</h3>
</div>
<div style="display: flex; align-items: center; justify-content: center;margin-bottom: 25px">
<h3 style="text-align: center">
[<a href="https://genforce.github.io/ctrl-x/">Page</a>]
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
[<a href="https://arxiv.org/abs/2406.07540">Paper</a>]
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
[<a href="https://github.com/genforce/ctrl-x">Code</a>]
</h3>
</div>
"""
description = """<div>
<p>
<b>Ctrl-X</b> is a simple training-free and guidance-free framework for text-to-image (T2I) generation with
structure and appearance control. Given structure and appearance images, Ctrl-X designs feedforward structure
control to enable structure alignment with the arbitrary structure image and semantic-aware appearance transfer
to facilitate the appearance transfer from the appearance image.
</p>
<p>
Here are some notes and tips for this demo:
</p>
<ul>
<li> On input images:
<ul>
<li>
If both the structure and appearance images are provided, then Ctrl-X does <i>structure and
appearance</i> control.
</li>
<li>
If only the structure image is provided, then Ctrl-X does <i>structure-only</i> control and the
appearance image is jointly generated with the output image.
</li>
<li>
Similarly, if only the appearance image is provided, then Ctrl-X does <i>appearance-only</i>
control.
</li>
</ul>
</li>
<li> On prompts:
<ul>
<li>
Though the output prompt can affect the output image to a noticeable extent, the "accuracy" of the
structure and appearance prompts are not impactful to the final image.
</li>
<li>
If the structure or appearance prompt is left blank, then it uses the (non-optional) output prompt
by default.
</li>
</ul>
</li>
<li> On control schedules:
<ul>
<li>
When "Use advanced config" is <b>OFF</b>, the demo uses the structure guidance
(<span class="mono">structure_conv</span> and <span class="mono">structure_attn</span>
in the advanced config) and appearance guidance (<span class="mono">appearance_attn</span> in the
advanced config) sliders to change the control schedules.
</li>
<li>
Otherwise, the demo uses "Advanced control config," which allows per-layer structure and
appearance schedule control, along with self-recurrence control. <i>This should be used
carefully</i>, and we recommend switching "Use advanced config" <b>OFF</b> in most cases. (For the
examples provided at the bottom of the demo, the advanced config uses the default schedules that
may not be the best settings for these examples.)
</li>
</ul>
</li>
</ul>
<p>
Have fun! :D
</p>
</div>
"""
@spaces.GPU
def inference(
structure_image,
appearance_image,
prompt,
structure_prompt,
appearance_prompt,
positive_prompt="high quality",
negative_prompt="ugly, blurry, dark, low res, unrealistic",
guidance_scale=5.0,
structure_guidance_scale=5.0,
appearance_guidance_scale=5.0,
num_inference_steps=28,
eta=1.0,
seed=42,
width=1024,
height=1024,
structure_schedule=0.6,
appearance_schedule=0.6,
use_advanced_config=False,
control_config="",
progress=gr.Progress(track_tqdm=True)
):
torch.manual_seed(seed)
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = pipe.scheduler.timesteps
print(f"\nUsing the following control config (use_advanced_config={use_advanced_config}):")
if not use_advanced_config:
control_config = get_control_config(structure_schedule, appearance_schedule)
print(control_config, end="\n\n")
config = yaml.safe_load(control_config)
register_control(
model = pipe,
timesteps = timesteps,
control_schedule = config["control_schedule"],
control_target = config["control_target"],
)
pipe.safety_checker = None
pipe.requires_safety_checker = False
self_recurrence_schedule = get_self_recurrence_schedule(config["self_recurrence_schedule"], num_inference_steps)
pipe.set_progress_bar_config(desc="Ctrl-X inference")
refiner.set_progress_bar_config(desc="Refiner")
result, structure, appearance = pipe(
prompt = prompt,
structure_prompt = structure_prompt,
appearance_prompt = appearance_prompt,
structure_image = structure_image,
appearance_image = appearance_image,
num_inference_steps = num_inference_steps,
negative_prompt = negative_prompt,
positive_prompt = positive_prompt,
height = height,
width = width,
guidance_scale = guidance_scale,
structure_guidance_scale = structure_guidance_scale,
appearance_guidance_scale = appearance_guidance_scale,
eta = eta,
output_type = "pil",
return_dict = False,
control_schedule = config["control_schedule"],
self_recurrence_schedule = self_recurrence_schedule,
)
result_refiner = refiner(
image = pipe.refiner_args["latents"],
prompt = pipe.refiner_args["prompt"],
negative_prompt = pipe.refiner_args["negative_prompt"],
height = height,
width = width,
num_inference_steps = num_inference_steps,
guidance_scale = guidance_scale,
guidance_rescale = 0.7,
num_images_per_prompt = 1,
eta = eta,
output_type = "pil",
).images
del pipe.refiner_args
return [result[0], result_refiner[0], structure[0], appearance[0]]
with gr.Blocks(theme=gr.themes.Default(), css=css, title="Ctrl-X (SDXL v1.0)") as app:
gr.HTML(title)
with gr.Accordion("Instructions", open=False):
gr.HTML(description)
with gr.Row():
with gr.Column(scale=45):
with gr.Group():
kwargs = {} # {"width": 400, "height": 400}
with gr.Row():
structure_image = gr.Image(label="Upload structure image (optional)", type="pil", **kwargs)
appearance_image = gr.Image(label="Upload appearance image (optional)", type="pil", **kwargs)
with gr.Row():
structure_prompt = gr.Textbox(label="Structure prompt (optional)", placeholder="Describes the structure image")
appearance_prompt = gr.Textbox(label="Appearance prompt (optional)", placeholder="Describes the style image")
with gr.Row():
prompt = gr.Textbox(label="Output prompt", placeholder="Prompt which describes the output image")
with gr.Row():
positive_prompt = gr.Textbox(label="Positive prompt", value="high quality", placeholder="")
negative_prompt = gr.Textbox(label="Negative prompt", value="ugly, blurry, dark, low res, unrealistic", placeholder="")
with gr.Accordion("Advanced Options", open=False):
with gr.Row():
guidance_scale = gr.Slider(label="Target guidance scale", value=5.0, minimum=1, maximum=10)
structure_guidance_scale = gr.Slider(label="Structure guidance scale", value=5.0, minimum=1, maximum=10)
appearance_guidance_scale = gr.Slider(label="Appearance guidance scale", value=5.0, minimum=1, maximum=10)
with gr.Row():
num_inference_steps = gr.Slider(label="# inference steps", value=28, minimum=1, maximum=200, step=1)
eta = gr.Slider(label="Eta (noise)", value=1.0, minimum=0, maximum=1.0, step=0.01)
seed = gr.Slider(0, 2147483647, label="Seed", value=90095, step=1)
with gr.Row():
width = gr.Slider(label="Width", value=1024, minimum=256, maximum=2048, step=pipe.vae_scale_factor)
height = gr.Slider(label="Height", value=1024, minimum=256, maximum=2048, step=pipe.vae_scale_factor)
with gr.Row():
structure_schedule = gr.Slider(label="Structure schedule", value=0.6, minimum=0.0, maximum=1.0, step=0.01, scale=2)
appearance_schedule = gr.Slider(label="Appearance schedule", value=0.6, minimum=0.0, maximum=1.0, step=0.01, scale=2)
use_advanced_config = gr.Checkbox(label="Use advanced config", value=False, scale=1)
with gr.Row():
control_config = gr.Textbox(
label="Advanced control config", lines=20, value=get_control_config(0.6, 0.6), elem_classes=["config"], visible=False,
)
use_advanced_config.change(
fn=lambda value: gr.update(visible=value), inputs=use_advanced_config, outputs=control_config,
)
with gr.Row():
generate = gr.Button(value="Run")
with gr.Column(scale=55):
with gr.Group():
with gr.Row():
result_refiner = gr.Image(label="Output image w/ refiner", format="jpg", **kwargs)
with gr.Row():
result = gr.Image(label="Output image", format="jpg", **kwargs)
structure_recon = gr.Image(label="Structure image", format="jpg", **kwargs)
appearance_recon = gr.Image(label="Style image", format="jpg", **kwargs)
inputs = [
structure_image, appearance_image,
prompt, structure_prompt, appearance_prompt,
positive_prompt, negative_prompt,
guidance_scale, structure_guidance_scale, appearance_guidance_scale,
num_inference_steps, eta, seed,
width, height,
structure_schedule, appearance_schedule, use_advanced_config,
control_config,
]
outputs = [result, result_refiner, structure_recon, appearance_recon]
generate.click(inference, inputs=inputs, outputs=outputs)
examples = gr.Examples(
[
[
"assets/images/horse__point_cloud.jpg",
"assets/images/horse.jpg",
"a photo of a horse standing on grass",
"a 3D point cloud of a horse",
"",
],
[
"assets/images/cat__mesh.jpg",
"assets/images/tiger.jpg",
"a photo of a tiger standing on snow",
"a 3D mesh of a cat",
"",
],
[
"assets/images/dog__sketch.jpg",
"assets/images/squirrel.jpg",
"a photo of a squirrel",
"a sketch of a dog",
"",
],
[
"assets/images/living_room__seg.jpg",
"assets/images/van_gogh.jpg",
"a Van Gogh painting of a living room",
"a segmentation map of a living room",
"",
],
[
"assets/images/bedroom__sketch.jpg",
"assets/images/living_room_modern.jpg",
"a sketch of a bedroom",
"a photo of a modern bedroom during sunset",
"",
],
[
"assets/images/running__pose.jpg",
"assets/images/man_park.jpg",
"a photo of a man running in a park",
"a pose image of a person running",
"",
],
[
"assets/images/fruit_bowl.jpg",
"assets/images/grapes.jpg",
"a photo of a bowl of grapes in the trees",
"a photo of a bowl of fruits",
"",
],
[
"assets/images/bear_avocado__spatext.jpg",
None,
"a realistic photo of a bear and an avocado in a forest",
"a segmentation map of a bear and an avocado",
"",
],
[
"assets/images/cat__point_cloud.jpg",
None,
"an embroidery of a white cat sitting on a rock under the night sky",
"a 3D point cloud of a cat",
"",
],
[
"assets/images/library__mesh.jpg",
None,
"a Polaroid photo of an old library, sunlight streaming in",
"a 3D mesh of a library",
"",
],
[
"assets/images/knight__humanoid.jpg",
None,
"a photo of a medieval soldier standing on a barren field, raining",
"a 3D model of a person holding a sword and shield",
"",
],
[
"assets/images/person__mesh.jpg",
None,
"a photo of a Karate man performing in a cyberpunk city at night",
"a 3D mesh of a person",
"",
],
],
[
structure_image,
appearance_image,
prompt,
structure_prompt,
appearance_prompt,
],
examples_per_page=50,
cache_examples="lazy",
fn=inference,
outputs=[result, result_refiner, structure_recon, appearance_recon]
)
app.launch(debug=False, share=False)