Text-to-image

text-to-image ํŒŒ์ธํŠœ๋‹ ์Šคํฌ๋ฆฝํŠธ๋Š” experimental ์ƒํƒœ์ž…๋‹ˆ๋‹ค. ๊ณผ์ ํ•ฉํ•˜๊ธฐ ์‰ฝ๊ณ  ์น˜๋ช…์ ์ธ ๋ง๊ฐ๊ณผ ๊ฐ™์€ ๋ฌธ์ œ์— ๋ถ€๋”ชํžˆ๊ธฐ ์‰ฝ์Šต๋‹ˆ๋‹ค. ์ž์ฒด ๋ฐ์ดํ„ฐ์…‹์—์„œ ์ตœ์ƒ์˜ ๊ฒฐ๊ณผ๋ฅผ ์–ป์œผ๋ ค๋ฉด ๋‹ค์–‘ํ•œ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ํƒ์ƒ‰ํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค.

Stable Diffusion๊ณผ ๊ฐ™์€ text-to-image ๋ชจ๋ธ์€ ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ์—์„œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ฐ€์ด๋“œ๋Š” PyTorch ๋ฐ Flax๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ž์ฒด ๋ฐ์ดํ„ฐ์…‹์—์„œ CompVis/stable-diffusion-v1-4 ๋ชจ๋ธ๋กœ ํŒŒ์ธํŠœ๋‹ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค. ์ด ๊ฐ€์ด๋“œ์— ์‚ฌ์šฉ๋œ text-to-image ํŒŒ์ธํŠœ๋‹์„ ์œ„ํ•œ ๋ชจ๋“  ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๊ด€์‹ฌ์ด ์žˆ๋Š” ๊ฒฝ์šฐ ์ด ๋ฆฌํฌ์ง€ํ† ๋ฆฌ์—์„œ ์ž์„ธํžˆ ์ฐพ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•˜๊ธฐ ์ „์—, ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์˜ ํ•™์Šต dependency๋“ค์„ ์„ค์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt

Then initialize an 🤗 Accelerate environment:

accelerate config

๋ฆฌํฌ์ง€ํ† ๋ฆฌ๋ฅผ ์ด๋ฏธ ๋ณต์ œํ•œ ๊ฒฝ์šฐ, ์ด ๋‹จ๊ณ„๋ฅผ ์ˆ˜ํ–‰ํ•  ํ•„์š”๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. ๋Œ€์‹ , ๋กœ์ปฌ ์ฒดํฌ์•„์›ƒ ๊ฒฝ๋กœ๋ฅผ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋ช…์‹œํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ ๊ฑฐ๊ธฐ์—์„œ ๋กœ๋“œ๋ฉ๋‹ˆ๋‹ค.

Hardware requirements

With gradient_checkpointing and mixed_precision, you can fine-tune the model on a single 24GB GPU. For a larger batch_size and faster training, a GPU with 30GB of memory or more is recommended. You can also use JAX/Flax to fine-tune on TPUs or GPUs; see below for details.
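
As a rough back-of-the-envelope sketch of why 24GB is tight, the snippet below estimates only the static memory of training the UNet with AdamW and `--use_ema`. It assumes the Stable Diffusion v1 UNet has roughly 860M parameters kept in fp32 with frozen VAE and text encoder; activations and framework overhead are ignored, so this is an illustration, not a measurement.

```python
def static_training_memory_gb(num_params: int, use_ema: bool = True) -> float:
    """Rough memory (decimal GB) for fp32 weights, gradients, and AdamW state.

    Assumption: every trainable parameter is stored in fp32 (4 bytes), with one
    gradient buffer and two AdamW moment buffers of the same size, plus an
    optional EMA copy of the weights. Activations are NOT included.
    """
    bytes_per_value = 4  # fp32
    copies = 1 + 1 + 2  # weights + gradients + AdamW (first and second moments)
    if use_ema:
        copies += 1  # EMA shadow weights
    return num_params * bytes_per_value * copies / 1e9


# Assumed, approximate UNet size for Stable Diffusion v1 (illustration only).
UNET_PARAMS = 860_000_000
print(f"{static_training_memory_gb(UNET_PARAMS):.1f} GB")  # 17.2 GB
```

Activations come on top of this figure; gradient checkpointing and fp16 mixed precision mainly shrink that part.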

You can reduce memory usage even further by enabling memory-efficient attention with xFormers. Make sure xFormers is installed, then pass --enable_xformers_memory_efficient_attention to the training script.

xFormers is not available for Flax.
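
To avoid passing the xFormers flag on a machine where the package is missing, a launcher could check for it first. This is a minimal sketch; `build_memory_flags` is a hypothetical helper, and only the flag names come from this guide.

```python
import importlib.util


def xformers_available() -> bool:
    """Return True if the `xformers` package can be imported in this environment."""
    return importlib.util.find_spec("xformers") is not None


def build_memory_flags(has_xformers: bool) -> list[str]:
    """Hypothetical helper: assemble memory-saving flags for train_text_to_image.py."""
    flags = ["--gradient_checkpointing", "--mixed_precision=fp16"]
    if has_xformers:
        # Only request memory-efficient attention when xFormers is installed.
        flags.append("--enable_xformers_memory_efficient_attention")
    return flags


print(" ".join(build_memory_flags(xformers_available())))
```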

Hub์— ๋ชจ๋ธ ์—…๋กœ๋“œํ•˜๊ธฐ

ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋‹ค์Œ ์ธ์ˆ˜๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ๋ชจ๋ธ์„ ํ—ˆ๋ธŒ์— ์ €์žฅํ•ฉ๋‹ˆ๋‹ค:

  --push_to_hub

์ฒดํฌํฌ์ธํŠธ ์ €์žฅ ๋ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

ํ•™์Šต ์ค‘ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋Š” ์ผ์— ๋Œ€๋น„ํ•˜์—ฌ ์ •๊ธฐ์ ์œผ๋กœ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•ด ๋‘๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•˜๋ ค๋ฉด ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋‹ค์Œ ์ธ์ˆ˜๋ฅผ ๋ช…์‹œํ•ฉ๋‹ˆ๋‹ค.

  --checkpointing_steps=500

500์Šคํ…๋งˆ๋‹ค ์ „์ฒด ํ•™์Šต state๊ฐ€ 'output_dir'์˜ ํ•˜์œ„ ํด๋”์— ์ €์žฅ๋ฉ๋‹ˆ๋‹ค. ์ฒดํฌํฌ์ธํŠธ๋Š” 'checkpoint-'์— ์ง€๊ธˆ๊นŒ์ง€ ํ•™์Šต๋œ step ์ˆ˜์ž…๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด 'checkpoint-1500'์€ 1500 ํ•™์Šต step ํ›„์— ์ €์žฅ๋œ ์ฒดํฌํฌ์ธํŠธ์ž…๋‹ˆ๋‹ค.

ํ•™์Šต์„ ์žฌ๊ฐœํ•˜๊ธฐ ์œ„ํ•ด ์ฒดํฌํฌ์ธํŠธ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๋ ค๋ฉด '--resume_from_checkpoint' ์ธ์ˆ˜๋ฅผ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋ช…์‹œํ•˜๊ณ  ์žฌ๊ฐœํ•  ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ง€์ •ํ•˜์‹ญ์‹œ์˜ค. ์˜ˆ๋ฅผ ๋“ค์–ด ๋‹ค์Œ ์ธ์ˆ˜๋Š” 1500๊ฐœ์˜ ํ•™์Šต step ํ›„์— ์ €์žฅ๋œ ์ฒดํฌํฌ์ธํŠธ์—์„œ๋ถ€ํ„ฐ ํ›ˆ๋ จ์„ ์žฌ๊ฐœํ•ฉ๋‹ˆ๋‹ค.

  --resume_from_checkpoint="checkpoint-1500"
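
To decide which value to pass to `--resume_from_checkpoint`, you can look for the newest 'checkpoint-<step>' folder in 'output_dir'. This is a minimal sketch with a hypothetical `latest_checkpoint` helper; note that folders must be sorted by their integer step, because as plain strings "checkpoint-1500" sorts before "checkpoint-500".

```python
import os
import tempfile
from typing import Optional


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the name of the newest `checkpoint-<step>` folder, or None if there is none."""
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    if not dirs:
        return None
    # Compare by the numeric step, not alphabetically.
    return max(dirs, key=lambda d: int(d.split("-")[1]))


# Usage with a throwaway directory standing in for output_dir.
with tempfile.TemporaryDirectory() as out:
    for step in (500, 1000, 1500):
        os.makedirs(os.path.join(out, f"checkpoint-{step}"))
    found = latest_checkpoint(out)

print(found)  # checkpoint-1500
```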

ํŒŒ์ธํŠœ๋‹

๋‹ค์Œ๊ณผ ๊ฐ™์ด [Naruto BLIP ์บก์…˜](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) ๋ฐ์ดํ„ฐ์…‹์—์„œ ํŒŒ์ธํŠœ๋‹ ์‹คํ–‰์„ ์œ„ํ•ด [PyTorch ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค:
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/naruto-blip-captions"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-naruto-model" 
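
With `--train_batch_size=1` and `--gradient_accumulation_steps=4`, each optimizer update sees more than one image. The sketch below assumes the effective batch is the per-device batch times the accumulation steps times the number of processes, and that `--max_train_steps` counts optimizer updates; it runs on a single GPU by default.

```python
def effective_batch_size(train_batch_size: int, gradient_accumulation_steps: int, num_processes: int = 1) -> int:
    """Number of samples that contribute to a single optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values from the command above, assuming one GPU (one process).
batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
max_train_steps = 15000
print(batch, batch * max_train_steps)  # 4 60000
```

Raising `--gradient_accumulation_steps` trades wall-clock time for a larger effective batch without extra GPU memory.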

์ž์ฒด ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํŒŒ์ธํŠœ๋‹ํ•˜๋ ค๋ฉด ๐Ÿค— Datasets์—์„œ ์š”๊ตฌํ•˜๋Š” ํ˜•์‹์— ๋”ฐ๋ผ ๋ฐ์ดํ„ฐ์…‹์„ ์ค€๋น„ํ•˜์„ธ์š”. ๋ฐ์ดํ„ฐ์…‹์„ ํ—ˆ๋ธŒ์— ์—…๋กœ๋“œํ•˜๊ฑฐ๋‚˜ [ํŒŒ์ผ๋“ค์ด ์žˆ๋Š” ๋กœ์ปฌ ํด๋”๋ฅผ ์ค€๋น„](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์‚ฌ์šฉ์ž ์ปค์Šคํ…€ loading logic์„ ์‚ฌ์šฉํ•˜๋ ค๋ฉด ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์ˆ˜์ •ํ•˜์‹ญ์‹œ์˜ค. ๋„์›€์ด ๋˜๋„๋ก ์ฝ”๋“œ์˜ ์ ์ ˆํ•œ ์œ„์น˜์— ํฌ์ธํ„ฐ๋ฅผ ๋‚จ๊ฒผ์Šต๋‹ˆ๋‹ค. ๐Ÿค— ์•„๋ž˜ ์˜ˆ์ œ ์Šคํฌ๋ฆฝํŠธ๋Š” TRAIN_DIR์˜ ๋กœ์ปฌ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ๋ฅผ ํŒŒ์ธํŠœ๋‹ํ•˜๋Š” ๋ฐฉ๋ฒ•๊ณผ OUTPUT_DIR์—์„œ ๋ชจ๋ธ์„ ์ €์žฅํ•  ์œ„์น˜๋ฅผ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค:

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
export OUTPUT_DIR="path_to_save_model"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir=${OUTPUT_DIR}
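
The `--resolution=512 --center_crop --random_flip` flags control image preprocessing. The sketch below shows the geometry, assuming the shorter side is resized to `resolution` (longer side truncated to an int, roughly as torchvision's `Resize` does) before a centered square crop; `--random_flip` then mirrors the crop horizontally at random. The helper name is hypothetical.

```python
def resize_then_center_crop(width: int, height: int, resolution: int = 512):
    """Return the resized (w, h) and the (left, top, right, bottom) center-crop box."""
    # Scale the shorter side to `resolution`, keeping the aspect ratio.
    if width <= height:
        new_w, new_h = resolution, int(resolution * height / width)
    else:
        new_w, new_h = int(resolution * width / height), resolution
    # Cut a resolution x resolution square from the middle of the resized image.
    left = round((new_w - resolution) / 2)
    top = round((new_h - resolution) / 2)
    return (new_w, new_h), (left, top, left + resolution, top + resolution)


print(resize_then_center_crop(1024, 768))  # ((682, 512), (85, 0, 597, 512))
```

Center cropping discards the edges of non-square images, so subjects near the borders of your training images may be cut off.
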
[@duongna211](https://github.com/duongna21)์˜ ๊ธฐ์—ฌ๋กœ, Flax๋ฅผ ์‚ฌ์šฉํ•ด TPU ๋ฐ GPU์—์„œ Stable Diffusion ๋ชจ๋ธ์„ ๋” ๋น ๋ฅด๊ฒŒ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Š” TPU ํ•˜๋“œ์›จ์–ด์—์„œ ๋งค์šฐ ํšจ์œจ์ ์ด์ง€๋งŒ GPU์—์„œ๋„ ํ›Œ๋ฅญํ•˜๊ฒŒ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค. Flax ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋Š” gradient checkpointing๋‚˜ gradient accumulation๊ณผ ๊ฐ™์€ ๊ธฐ๋Šฅ์„ ์•„์ง ์ง€์›ํ•˜์ง€ ์•Š์œผ๋ฏ€๋กœ ๋ฉ”๋ชจ๋ฆฌ๊ฐ€ 30GB ์ด์ƒ์ธ GPU ๋˜๋Š” TPU v3๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.

์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•˜๊ธฐ ์ „์— ์š”๊ตฌ ์‚ฌํ•ญ์ด ์„ค์น˜๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์‹ญ์‹œ์˜ค:

pip install -U -r requirements_flax.txt

๊ทธ๋Ÿฌ๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์ด Flax ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export dataset_name="lambdalabs/naruto-blip-captions"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-naruto-model" 

์ž์ฒด ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํŒŒ์ธํŠœ๋‹ํ•˜๋ ค๋ฉด ๐Ÿค— Datasets์—์„œ ์š”๊ตฌํ•˜๋Š” ํ˜•์‹์— ๋”ฐ๋ผ ๋ฐ์ดํ„ฐ์…‹์„ ์ค€๋น„ํ•˜์„ธ์š”. ๋ฐ์ดํ„ฐ์…‹์„ ํ—ˆ๋ธŒ์— ์—…๋กœ๋“œํ•˜๊ฑฐ๋‚˜ [ํŒŒ์ผ๋“ค์ด ์žˆ๋Š” ๋กœ์ปฌ ํด๋”๋ฅผ ์ค€๋น„](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์‚ฌ์šฉ์ž ์ปค์Šคํ…€ loading logic์„ ์‚ฌ์šฉํ•˜๋ ค๋ฉด ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์ˆ˜์ •ํ•˜์‹ญ์‹œ์˜ค. ๋„์›€์ด ๋˜๋„๋ก ์ฝ”๋“œ์˜ ์ ์ ˆํ•œ ์œ„์น˜์— ํฌ์ธํ„ฐ๋ฅผ ๋‚จ๊ฒผ์Šต๋‹ˆ๋‹ค. ๐Ÿค— ์•„๋ž˜ ์˜ˆ์ œ ์Šคํฌ๋ฆฝํŠธ๋Š” TRAIN_DIR์˜ ๋กœ์ปฌ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ๋ฅผ ํŒŒ์ธํŠœ๋‹ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค:

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-naruto-model"

LoRA

Text-to-image ๋ชจ๋ธ ํŒŒ์ธํŠœ๋‹์„ ์œ„ํ•ด, ๋Œ€๊ทœ๋ชจ ๋ชจ๋ธ ํ•™์Šต์„ ๊ฐ€์†ํ™”ํ•˜๊ธฐ ์œ„ํ•œ ํŒŒ์ธํŠœ๋‹ ๊ธฐ์ˆ ์ธ LoRA(Low-Rank Adaptation of Large Language Models)๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ž์„ธํ•œ ๋‚ด์šฉ์€ LoRA ํ•™์Šต ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.

Inference

ํ—ˆ๋ธŒ์˜ ๋ชจ๋ธ ๊ฒฝ๋กœ ๋˜๋Š” ๋ชจ๋ธ ์ด๋ฆ„์„ [StableDiffusionPipeline]์— ์ „๋‹ฌํ•˜์—ฌ ์ถ”๋ก ์„ ์œ„ํ•ด ํŒŒ์ธ ํŠœ๋‹๋œ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the fine-tuned weights from a local folder or a Hub repository id.
model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

# Generate an image from a text prompt and save it.
image = pipe(prompt="yoda").images[0]
image.save("yoda-naruto.png")
```

With Flax, use [FlaxStableDiffusionPipeline] instead:

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

model_path = "path_to_saved_model"
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)

prompt = "yoda naruto"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

# One prompt per device so that generation runs in parallel on every device.
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# Replicate the parameters and split the RNG key and inputs across devices.
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
# Merge the device axis back into a flat batch and convert to PIL images.
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
images[0].save("yoda-naruto.png")
```