Diffusers documentation

DreamBooth


DreamBooth is a training technique that updates the entire diffusion model by training on just a few images of a subject or style. It works by associating a special word in the prompt with the example images.

If you’re training on a GPU with limited VRAM, you should try enabling the gradient_checkpointing and mixed_precision parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with xFormers. JAX/Flax training is also supported for efficient training on TPUs and GPUs, but it doesn’t support gradient checkpointing or xFormers. You should have a GPU with more than 30GB of memory if you want to train faster with Flax.
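
For example, a memory-constrained run might combine these options (a sketch only; it assumes xFormers is installed and that your GPU supports fp16 mixed precision):

accelerate launch train_dreambooth.py \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --enable_xformers_memory_efficient_attention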

This guide explores the train_dreambooth.py script to help you become more familiar with it and show you how to adapt it for your own use case.

Before running the script, make sure you install the library from source:

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

Navigate to the example folder with the training script and install the required dependencies for the script you’re using:

cd examples/dreambooth
pip install -r requirements.txt

🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It’ll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate Quick tour to learn more.

Initialize an 🤗 Accelerate environment:

accelerate config

To set up a default 🤗 Accelerate environment without choosing any configurations:

accelerate config default

Or if your environment doesn’t support an interactive shell, like a notebook, you can use:

from accelerate.utils import write_basic_config

write_basic_config()

Lastly, if you want to train a model on your own dataset, take a look at the Create a dataset for training guide to learn how to create a dataset that works with the training script.

The following sections highlight parts of the training script that are important for understanding how to modify it, but they don’t cover every aspect of the script in detail. If you’re interested in learning more, feel free to read through the script and let us know if you have any questions or concerns.

Script parameters

DreamBooth is very sensitive to training hyperparameters, and it is easy to overfit. Read the Training Stable Diffusion with Dreambooth using 🧨 Diffusers blog post for recommended settings for different subjects to help you choose the appropriate hyperparameters.

The training script offers many parameters for customizing your training run. All of the parameters and their descriptions are found in the parse_args() function. The parameters are set with default values that should work pretty well out-of-the-box, but you can also set your own values in the training command if you’d like.

For example, to train in the bf16 format:

accelerate launch train_dreambooth.py \
    --mixed_precision="bf16"

Some basic and important parameters to know and specify are:

  • --pretrained_model_name_or_path: the name of the model on the Hub or a local path to the pretrained model
  • --instance_data_dir: path to a folder containing the training dataset (example images)
  • --instance_prompt: the text prompt that contains the special word for the example images
  • --train_text_encoder: whether to also train the text encoder
  • --output_dir: where to save the trained model
  • --push_to_hub: whether to push the trained model to the Hub
  • --checkpointing_steps: frequency of saving a checkpoint as the model trains; this is useful because if training is interrupted for any reason, you can continue from that checkpoint by adding --resume_from_checkpoint to your training command, as shown in the sketch below
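
For example, a run that saves a checkpoint every 500 steps and resumes from the most recent one might look like this (a sketch; "latest" can also be replaced with the path to a specific checkpoint folder):

accelerate launch train_dreambooth.py \
  --checkpointing_steps=500 \
  --resume_from_checkpoint="latest"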

Min-SNR weighting

The Min-SNR weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting epsilon (noise) or v_prediction, and Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.

Add the --snr_gamma parameter and set it to the recommended value of 5.0:

accelerate launch train_dreambooth.py \
  --snr_gamma=5.0

Prior preservation loss

Prior preservation loss is a method that uses a model’s own generated samples to help it learn how to generate more diverse images. Because these generated sample images belong to the same class as the images you provided, they help the model retain what it has learned about the class and how it can use what it already knows about the class to make new compositions.

  • --with_prior_preservation: whether to use prior preservation loss
  • --prior_loss_weight: controls the influence of the prior preservation loss on the model
  • --class_data_dir: path to a folder containing the generated class sample images
  • --class_prompt: the text prompt describing the class of the generated sample images

accelerate launch train_dreambooth.py \
  --with_prior_preservation \
  --prior_loss_weight=1.0 \
  --class_data_dir="path/to/class/images" \
  --class_prompt="text prompt describing class"

Train text encoder

To improve the quality of the generated outputs, you can also train the text encoder in addition to the UNet. This requires additional memory, and you’ll need a GPU with at least 24GB of VRAM. If you have the necessary hardware, then training the text encoder produces better results, especially when generating images of faces. Enable this option by passing:

accelerate launch train_dreambooth.py \
  --train_text_encoder

Training script

DreamBooth comes with its own dataset classes:

  • DreamBoothDataset: preprocesses the images and class images, and tokenizes the prompts for training
  • PromptDataset: prepares the prompts used to generate the class images (a minimal sketch is shown below)
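
For reference, PromptDataset is a small wrapper, roughly along the lines of this simplified sketch (not the exact code from the script):

from torch.utils.data import Dataset

class PromptDataset(Dataset):
    "A simple dataset that repeats the class prompt so class images can be generated in batches."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # each example carries the prompt and its index so generated images can be tracked
        return {"prompt": self.prompt, "index": index}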

If you enabled prior preservation loss, the class images are generated here:

sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)

for example in tqdm(
    sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
    images = pipeline(example["prompt"]).images

Next is the main() function which handles setting up the dataset for training and the training loop itself. The script loads the tokenizer, scheduler and models:

# Load the tokenizer
if args.tokenizer_name:
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        use_fast=False,
    )

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)

if model_has_vae(args):
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
    )
else:
    vae = None

unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

Then, it’s time to create the training dataset and DataLoader from DreamBoothDataset:

train_dataset = DreamBoothDataset(
    instance_data_root=args.instance_data_dir,
    instance_prompt=args.instance_prompt,
    class_data_root=args.class_data_dir if args.with_prior_preservation else None,
    class_prompt=args.class_prompt,
    class_num=args.num_class_images,
    tokenizer=tokenizer,
    size=args.resolution,
    center_crop=args.center_crop,
    encoder_hidden_states=pre_computed_encoder_hidden_states,
    class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
    tokenizer_max_length=args.tokenizer_max_length,
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
    num_workers=args.dataloader_num_workers,
)

Lastly, the training loop takes care of the remaining steps such as converting images to latent space, adding noise to the input, predicting the noise residual, and calculating the loss.

If you want to learn more about how the training loop works, check out the Understanding pipelines, models and schedulers tutorial which breaks down the basic pattern of the denoising process.
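
For intuition, here is a simplified sketch of a single training step, reusing the vae, unet, noise_scheduler, and text_encoder loaded above (the actual loop also handles prior preservation, mixed precision, gradient accumulation, and learning rate scheduling):

import torch
import torch.nn.functional as F

for batch in train_dataloader:
    # convert images to latent space with the VAE
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
    latents = latents * vae.config.scaling_factor

    # sample noise and a random timestep for each image
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    )

    # add noise to the latents according to the noise schedule (forward diffusion)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # predict the noise residual conditioned on the prompt embeddings
    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # compute the loss and update the model
    loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()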

Launch the script

You’re now ready to launch the training script! 🚀

For this guide, you’ll download some images of a dog and store them in a directory. But remember, you can create and use your own dataset if you want (see the Create a dataset for training guide).

from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)

Set the environment variable MODEL_NAME to a model id on the Hub or a path to a local model, INSTANCE_DIR to the path where you just downloaded the dog images to, and OUTPUT_DIR to where you want to save the model. You’ll use sks as the special word to tie the training to.

If you’re interested in following along with the training process, you can periodically save generated images as training progresses. Add the following parameters to the training command:

--validation_prompt="a photo of a sks dog"
--num_validation_images=4
--validation_steps=100

One more thing before you launch the script! Depending on the GPU you have, you may need to enable certain optimizations to train DreamBooth.


On a 16GB GPU, you can use the bitsandbytes 8-bit optimizer and gradient checkpointing to train a DreamBooth model. Install bitsandbytes:

pip install bitsandbytes

Then, add the following parameters to your training command:

accelerate launch train_dreambooth.py \
  --gradient_checkpointing \
  --use_8bit_adam

Now set the MODEL_NAME, INSTANCE_DIR, and OUTPUT_DIR environment variables and launch the script:

export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
export INSTANCE_DIR="./dog"
export OUTPUT_DIR="path_to_saved_model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400 \
  --push_to_hub

Once training is complete, you can use your newly trained model for inference!

Can’t wait to try your model for inference before training is complete? 🤭 Make sure you have the latest version of 🤗 Accelerate installed.

from diffusers import DiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel
import torch

unet = UNet2DConditionModel.from_pretrained("path/to/model/checkpoint-100/unet")

# if you trained with `--train_text_encoder`, make sure to also load the text encoder
text_encoder = CLIPTextModel.from_pretrained("path/to/model/checkpoint-100/text_encoder")

pipeline = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, text_encoder=text_encoder, dtype=torch.float16,
).to("cuda")

image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")

To run inference with the final trained model:

from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained("path_to_saved_model", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")

LoRA

LoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and the resulting weights are easier to store because they are much smaller (~100MB). Use the train_dreambooth_lora.py script to train with LoRA.

The LoRA training script is discussed in more detail in the LoRA training guide.
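
For illustration, loading the resulting LoRA weights on top of the base model at inference could look like this (a sketch; the weights path is a placeholder for wherever train_dreambooth_lora.py saved your output):

from diffusers import DiffusionPipeline
import torch

# load the base model, then attach the small LoRA weights produced by DreamBooth LoRA training
pipeline = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("path/to/lora/output_dir")

image = pipeline("A photo of sks dog in a bucket").images[0]
image.save("dog-bucket.png")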

Stable Diffusion XL

Stable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text encoder to its architecture. Use the train_dreambooth_lora_sdxl.py script to train an SDXL model with LoRA.

The SDXL training script is discussed in more detail in the SDXL training guide.

DeepFloyd IF

DeepFloyd IF is a cascading pixel diffusion model with three stages. The first stage generates a base image, and the second and third stages progressively upscale the base image into a high-resolution 1024x1024 image. Use the train_dreambooth_lora.py or train_dreambooth.py scripts to train a DeepFloyd IF model with LoRA or the full model.

DeepFloyd IF uses predicted variance, but the Diffusers training scripts use predicted error, so the trained DeepFloyd IF models are switched to a fixed variance schedule. The training scripts will update the scheduler config of the fully trained model for you. However, when you load the saved LoRA weights, you must also update the pipeline’s scheduler config.

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", use_safetensors=True)

pipe.load_lora_weights("<lora weights path>")

# Update scheduler config to fixed variance schedule
pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small")

The stage 2 model requires additional validation images to upscale. You can download and use a downsized version of the training images for this.

from huggingface_hub import snapshot_download

local_dir = "./dog_downsized"
snapshot_download(
    "diffusers/dog-example-downsized",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)

The code sample below provides a brief overview of how to train a DeepFloyd IF model with a combination of DreamBooth and LoRA. Some important parameters to note are:

  • --resolution=64, a much smaller resolution is required because DeepFloyd IF is a pixel diffusion model; it works on uncompressed pixels, so the input images must be smaller
  • --pre_compute_text_embeddings, compute the text embeddings ahead of time to save memory because the T5 text encoder can take up a lot of memory
  • --tokenizer_max_length=77, you can use a longer default text length with T5 as the text encoder but the default model encoding procedure uses a shorter text length
  • --text_encoder_use_attention_mask, to pass the attention mask to the text encoder

Training stage 1 of DeepFloyd IF with LoRA and DreamBooth requires ~28GB of memory.

export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="dreambooth_dog_lora"

accelerate launch train_dreambooth_lora.py \
  --report_to wandb \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a sks dog" \
  --resolution=64 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --scale_lr \
  --max_train_steps=1200 \
  --validation_prompt="a sks dog" \
  --validation_epochs=25 \
  --checkpointing_steps=100 \
  --pre_compute_text_embeddings \
  --tokenizer_max_length=77 \
  --text_encoder_use_attention_mask

Training tips

Training the DeepFloyd IF model can be challenging, but here are some tips that we’ve found helpful:

  • LoRA is sufficient for training the stage 1 model because the model’s low resolution makes representing finer details difficult regardless.
  • For common or simple objects, you don’t necessarily need to finetune the upscaler. Make sure the prompt passed to the upscaler is adjusted to remove the new token from the instance prompt. For example, if your stage 1 prompt is “a sks dog” then your stage 2 prompt should be “a dog”.
  • For finer details like faces, fully training the stage 2 upscaler is better than training the stage 2 model with LoRA. It also helps to use lower learning rates with larger batch sizes.
  • Lower learning rates should be used to train the stage 2 model.
  • The DDPMScheduler works better than the DPMSolver used in the training scripts; a sketch of swapping it in at inference time follows this list.
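
A minimal sketch of that last tip, swapping the pipeline’s scheduler to DDPMScheduler before sampling (the model id is only an example):

from diffusers import DiffusionPipeline, DDPMScheduler

pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", use_safetensors=True)

# replace the default scheduler, reusing its existing config
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)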

Next steps

Congratulations on training your DreamBooth model! To learn more about how to use your new model, the following guide may be helpful:

  • Learn how to load a DreamBooth model for inference if you trained your model with LoRA.