ControlNet training example for Stable Diffusion XL (SDXL)

The train_controlnet_sdxl.py script shows how to implement the ControlNet training procedure and adapt it for Stable Diffusion XL.

Running locally with PyTorch

Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

Important

To make sure you can successfully run the latest versions of the example scripts, we highly recommend installing from source and keeping the install up to date. We update the example scripts frequently, and some of them have example-specific requirements. To do this, execute the following steps in a new virtual environment:

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .

Then cd into the examples/controlnet folder and run:

pip install -r requirements_sdxl.txt

And initialize an 🤗Accelerate environment with:

accelerate config

Or, for a default Accelerate configuration without answering questions about your environment:

accelerate config default

Or, if your environment doesn't support an interactive shell (e.g., a notebook):

from accelerate.utils import write_basic_config
write_basic_config()

When running accelerate config, setting torch compile mode to True can give dramatic speedups.

Circle filling dataset

The original dataset is hosted in the ControlNet repo. We re-uploaded it to the Hugging Face Hub as fusing/fill50k so that it is compatible with 🤗 Datasets. Note that the training script uses Datasets to handle data loading.

Training

Our training examples use two test conditioning images. They can be downloaded by running:

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png

Then run huggingface-cli login to log into your Hugging Face account. This is required to push the trained ControlNet parameters to the Hugging Face Hub.

export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path to save model"

accelerate launch train_controlnet_sdxl.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --mixed_precision="fp16" \
 --resolution=1024 \
 --learning_rate=1e-5 \
 --max_train_steps=15000 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --validation_steps=100 \
 --train_batch_size=1 \
 --gradient_accumulation_steps=4 \
 --report_to="wandb" \
 --seed=42 \
 --push_to_hub

To better track our training experiments, we're using the following flags in the command above:

  • report_to="wandb" ensures the training runs are tracked on Weights and Biases. To use it, be sure to install wandb with pip install wandb.
  • validation_image, validation_prompt, and validation_steps let the script do a few validation inference runs during training. This allows us to qualitatively check whether training is progressing as expected.
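
The validation flags in the command pass two conditioning images, two prompts, and a validation interval of 100 steps. The sketch below illustrates how such inputs might be paired and how many validation rounds that interval implies. The helper name pair_validation_inputs is hypothetical, and both the pairing rule (equal counts are matched one to one; a single image or a single prompt is reused for every item on the other side) and the idea that validation runs whenever the step count is a multiple of validation_steps are assumptions about the script's behavior, not something this README states.

```python
def pair_validation_inputs(images, prompts):
    """Pair conditioning images with prompts (hypothetical helper, assumed rule)."""
    if len(images) == len(prompts):
        # Equal counts: image i is validated with prompt i.
        return list(zip(images, prompts))
    if len(images) == 1:
        # One image: reuse it for every prompt.
        return [(images[0], prompt) for prompt in prompts]
    if len(prompts) == 1:
        # One prompt: reuse it for every image.
        return [(image, prompts[0]) for image in images]
    raise ValueError("Provide one image, one prompt, or equal numbers of both.")


images = ["./conditioning_image_1.png", "./conditioning_image_2.png"]
prompts = [
    "red circle with blue background",
    "cyan circle with brown floral background",
]
pairs = pair_validation_inputs(images, prompts)

# Values taken from the training command above.
max_train_steps = 15000
validation_steps = 100
validation_rounds = max_train_steps // validation_steps

for image, prompt in pairs:
    print(f"{image} -> {prompt}")
print(f"validation rounds during training: {validation_rounds}")
```

Under these assumptions, each of the 150 rounds generates images for both image and prompt pairs, so a shorter interval makes progress easier to follow but adds inference time to the run.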

Our experiments were conducted on a single 40GB A100 GPU.
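
For a sense of scale, the batch settings in the training command can be turned into an effective batch size and a rough number of passes over the data. This is a back-of-the-envelope sketch: it assumes that max_train_steps counts optimizer updates, that a single process is used (matching the single GPU mentioned above), and that the dataset holds 50,000 examples, which is only inferred from the name fill50k.

```python
# Values taken from the training command above.
train_batch_size = 1
gradient_accumulation_steps = 4
max_train_steps = 15000

# Assumptions, not stated in this README.
num_processes = 1       # one GPU, one process
dataset_size = 50_000   # inferred from the dataset name "fill50k"

# Samples that contribute to each optimizer update.
effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes

# Total samples consumed if every step is an optimizer update.
samples_seen = effective_batch_size * max_train_steps

# Approximate number of passes over the dataset.
approx_epochs = samples_seen / dataset_size

print(f"effective batch size: {effective_batch_size}")
print(f"samples seen: {samples_seen}")
print(f"approximate epochs: {approx_epochs:.2f}")
```

If more GPUs are used, num_processes grows and the effective batch size grows with it, so the learning rate or step count may need revisiting.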

Inference

Once training is done, we can perform inference like so:

from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet_path = "path to controlnet"

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe.enable_model_cpu_offload()

control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
prompt = "pale golden rod circle with old lace background"

# generate image
generator = torch.manual_seed(0)
image = pipe(
    prompt, num_inference_steps=20, generator=generator, image=control_image
).images[0]
image.save("./output.png")

Notes

Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. For this reason, the script exposes a CLI argument, --pretrained_vae_model_name_or_path, that lets you specify the location of an alternative VAE (such as madebyollin/sdxl-vae-fp16-fix).
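
One common way numerical instability shows up is overflow in half precision, which is relevant here because the training command uses --mixed_precision="fp16". The sketch below only illustrates the limited range of float16 using Python's standard library; it does not reproduce the VAE issue, and treating overflow as the failure mode of SDXL's VAE is an assumption. The helper name fits_in_fp16 is hypothetical.

```python
import struct

# Largest finite value representable in IEEE 754 half precision.
FP16_MAX = 65504.0


def fits_in_fp16(value: float) -> bool:
    """Return True if a finite value can be packed as float16 without overflow."""
    try:
        struct.pack("<e", value)  # "e" is the half-precision format code
        return True
    except OverflowError:
        return False


for value in (1024.0, FP16_MAX, 70000.0):
    print(f"{value}: fits in fp16 = {fits_in_fp16(value)}")
```

Values above the float16 maximum cannot be stored as finite numbers, which is why a model whose intermediate values grow large may need either higher precision or weights adjusted to keep those values in range.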

If you use this VAE during training, you need to use it during inference too. To do so, load the VAE and pass it to the pipeline:

+ from diffusers import AutoencoderKL
+ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16,
+   vae=vae,
)