Training on an 8xA100 machine

#26
by mashreve - opened

Thank you for sharing this model. I am trying to train it on my own control using the training script, but I run out of memory (OOM) on the same machine configuration (8xA100). Could you share the accelerate configuration you used? Did you have to use a fully sharded approach? Thank you.

I'm not sure we officially released the training script. Which one are you looking at? You can also just keep decreasing the batch size until you stop getting OOMs. IIRC, batch size wasn't terribly important.
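When lowering the per-GPU batch size, it can help to keep an eye on the effective batch size, which under plain data-parallel training is the per-device batch size times the gradient accumulation steps times the number of processes. The sketch below only does that arithmetic. The values are taken from this thread (batch size 1, 4 accumulation steps, 8 GPUs), and `TARGET_BATCH` is a hypothetical number chosen for illustration, not something the model authors reported.

```shell
#!/bin/bash
# Effective batch size under data-parallel training (arithmetic only).
PER_DEVICE_BATCH=1    # --train_batch_size from the thread
GRAD_ACCUM_STEPS=4    # --gradient_accumulation_steps from the thread
NUM_GPUS=8            # assumption: one process per A100

EFFECTIVE_BATCH=$((PER_DEVICE_BATCH * GRAD_ACCUM_STEPS * NUM_GPUS))
echo "effective batch size: ${EFFECTIVE_BATCH}"

# Going the other way: accumulation steps needed to reach a target batch size.
TARGET_BATCH=32       # hypothetical target, for illustration only
NEEDED_ACCUM=$((TARGET_BATCH / (PER_DEVICE_BATCH * NUM_GPUS)))
echo "accumulation steps for target ${TARGET_BATCH}: ${NEEDED_ACCUM}"
```

Gradient accumulation trades wall-clock time for memory, so it does not help if a single sample already does not fit on one GPU.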

Thanks for the response. I am using the script from the README (https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md), which is based on the test circle dataset.
I am using a batch size of 1 and still getting an OOM with the launch script below.

[Update]: I see now that in the reference training script, a single A100 was used... If at all possible, it would be great to get the modifications you made to get it working on 8 GPUs.

#!/bin/bash
export MODEL_DIR="sdxl-vae-fp16-fix"
export OUTPUT_DIR="test_circle_training"
accelerate launch train_controlnet_sdxl.py \
  --pretrained_model_name_or_path="$MODEL_DIR" \
  --output_dir="$OUTPUT_DIR" \
  --dataset_name=fusing/fill50k \
  --mixed_precision="fp16" \
  --resolution=512 \
  --learning_rate=1e-5 \
  --max_train_steps=15000 \
  --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
  --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
  --validation_steps=100 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --report_to="wandb" \
  --seed=42
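For anyone hitting the same OOM, here is a hedged sketch of a lower-memory, multi-GPU variant of the launch above. It is not the configuration the model authors used, which this thread does not state. It assumes that `accelerate launch` is given `--multi_gpu` and `--num_processes`, and that `train_controlnet_sdxl.py` exposes `--gradient_checkpointing`, `--use_8bit_adam` (needs `bitsandbytes` installed) and `--set_grads_to_none`, as I recall from the diffusers examples. Confirm with `python train_controlnet_sdxl.py --help` before relying on it. `DRY_RUN` and `NUM_GPUS` are hypothetical helper variables, and the validation and logging flags are left out for brevity.

```shell
#!/bin/bash
# Sketch only: NOT the model authors' configuration.
MODEL_DIR="${MODEL_DIR:-sdxl-vae-fp16-fix}"
OUTPUT_DIR="${OUTPUT_DIR:-test_circle_training}"
NUM_GPUS="${NUM_GPUS:-8}"   # assumption: one process per A100
DRY_RUN="${DRY_RUN:-1}"     # hypothetical switch: 1 = print only, 0 = launch

# Build the command as positional parameters so quoting survives intact.
# The last three flags are the memory-saving additions.
set -- accelerate launch --multi_gpu --num_processes="$NUM_GPUS" \
  train_controlnet_sdxl.py \
  --pretrained_model_name_or_path="$MODEL_DIR" \
  --output_dir="$OUTPUT_DIR" \
  --dataset_name=fusing/fill50k \
  --mixed_precision=fp16 \
  --resolution=512 \
  --learning_rate=1e-5 \
  --max_train_steps=15000 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --seed=42 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --set_grads_to_none

if [ "$DRY_RUN" = "1" ]; then
  # Print the command instead of running it.
  printf '%s\n' "$*"
else
  "$@"
fi
```

Run it with `DRY_RUN=0` to actually launch. If memory is still exhausted with these flags, a sharded setup (DeepSpeed ZeRO or FSDP, selected through `accelerate config`) would be the next thing to try, though whether the authors needed one is not answered here.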
