Spaces:
Running
on
Zero
A newer version of the Gradio SDK is available:
5.9.1
Text-to-image
text-to-image ํ์ธํ๋ ์คํฌ๋ฆฝํธ๋ experimental ์ํ์ ๋๋ค. ๊ณผ์ ํฉํ๊ธฐ ์ฝ๊ณ ์น๋ช ์ ์ธ ๋ง๊ฐ๊ณผ ๊ฐ์ ๋ฌธ์ ์ ๋ถ๋ชํ๊ธฐ ์ฝ์ต๋๋ค. ์์ฒด ๋ฐ์ดํฐ์ ์์ ์ต์์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ผ๋ ค๋ฉด ๋ค์ํ ํ์ดํผํ๋ผ๋ฏธํฐ๋ฅผ ํ์ํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
Stable Diffusion๊ณผ ๊ฐ์ text-to-image ๋ชจ๋ธ์ ํ
์คํธ ํ๋กฌํํธ์์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํฉ๋๋ค. ์ด ๊ฐ์ด๋๋ PyTorch ๋ฐ Flax๋ฅผ ์ฌ์ฉํ์ฌ ์์ฒด ๋ฐ์ดํฐ์
์์ CompVis/stable-diffusion-v1-4
๋ชจ๋ธ๋ก ํ์ธํ๋ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค. ์ด ๊ฐ์ด๋์ ์ฌ์ฉ๋ text-to-image ํ์ธํ๋์ ์ํ ๋ชจ๋ ํ์ต ์คํฌ๋ฆฝํธ์ ๊ด์ฌ์ด ์๋ ๊ฒฝ์ฐ ์ด ๋ฆฌํฌ์งํ ๋ฆฌ์์ ์์ธํ ์ฐพ์ ์ ์์ต๋๋ค.
์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๊ธฐ ์ ์, ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ํ์ต dependency๋ค์ ์ค์นํด์ผ ํฉ๋๋ค:
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
๊ทธ๋ฆฌ๊ณ ๐คAccelerate ํ๊ฒฝ์ ์ด๊ธฐํํฉ๋๋ค:
accelerate config
๋ฆฌํฌ์งํ ๋ฆฌ๋ฅผ ์ด๋ฏธ ๋ณต์ ํ ๊ฒฝ์ฐ, ์ด ๋จ๊ณ๋ฅผ ์ํํ ํ์๊ฐ ์์ต๋๋ค. ๋์ , ๋ก์ปฌ ์ฒดํฌ์์ ๊ฒฝ๋ก๋ฅผ ํ์ต ์คํฌ๋ฆฝํธ์ ๋ช ์ํ ์ ์์ผ๋ฉฐ ๊ฑฐ๊ธฐ์์ ๋ก๋๋ฉ๋๋ค.
ํ๋์จ์ด ์๊ตฌ ์ฌํญ
gradient_checkpointing
๋ฐ mixed_precision
์ ์ฌ์ฉํ๋ฉด ๋จ์ผ 24GB GPU์์ ๋ชจ๋ธ์ ํ์ธํ๋ํ ์ ์์ต๋๋ค. ๋ ๋์ batch_size
์ ๋ ๋น ๋ฅธ ํ๋ จ์ ์ํด์๋ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ 30GB ์ด์์ธ GPU๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ต๋๋ค. TPU ๋๋ GPU์์ ํ์ธํ๋์ ์ํด JAX๋ Flax๋ฅผ ์ฌ์ฉํ ์๋ ์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ ์๋๋ฅผ ์ฐธ์กฐํ์ธ์.
xFormers๋ก memory efficient attention์ ํ์ฑํํ์ฌ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ํจ์ฌ ๋ ์ค์ผ ์ ์์ต๋๋ค. xFormers๊ฐ ์ค์น๋์ด ์๋์ง ํ์ธํ๊ณ --enable_xformers_memory_efficient_attention
๋ฅผ ํ์ต ์คํฌ๋ฆฝํธ์ ๋ช
์ํฉ๋๋ค.
xFormers๋ Flax์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
Hub์ ๋ชจ๋ธ ์ ๋ก๋ํ๊ธฐ
ํ์ต ์คํฌ๋ฆฝํธ์ ๋ค์ ์ธ์๋ฅผ ์ถ๊ฐํ์ฌ ๋ชจ๋ธ์ ํ๋ธ์ ์ ์ฅํฉ๋๋ค:
--push_to_hub
์ฒดํฌํฌ์ธํธ ์ ์ฅ ๋ฐ ๋ถ๋ฌ์ค๊ธฐ
ํ์ต ์ค ๋ฐ์ํ ์ ์๋ ์ผ์ ๋๋นํ์ฌ ์ ๊ธฐ์ ์ผ๋ก ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํด ๋๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ๋ ค๋ฉด ํ์ต ์คํฌ๋ฆฝํธ์ ๋ค์ ์ธ์๋ฅผ ๋ช ์ํฉ๋๋ค.
--checkpointing_steps=500
500์คํ ๋ง๋ค ์ ์ฒด ํ์ต state๊ฐ 'output_dir'์ ํ์ ํด๋์ ์ ์ฅ๋ฉ๋๋ค. ์ฒดํฌํฌ์ธํธ๋ 'checkpoint-'์ ์ง๊ธ๊น์ง ํ์ต๋ step ์์ ๋๋ค. ์๋ฅผ ๋ค์ด 'checkpoint-1500'์ 1500 ํ์ต step ํ์ ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ์ ๋๋ค.
ํ์ต์ ์ฌ๊ฐํ๊ธฐ ์ํด ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ถ๋ฌ์ค๋ ค๋ฉด '--resume_from_checkpoint' ์ธ์๋ฅผ ํ์ต ์คํฌ๋ฆฝํธ์ ๋ช ์ํ๊ณ ์ฌ๊ฐํ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ง์ ํ์ญ์์ค. ์๋ฅผ ๋ค์ด ๋ค์ ์ธ์๋ 1500๊ฐ์ ํ์ต step ํ์ ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ์์๋ถํฐ ํ๋ จ์ ์ฌ๊ฐํฉ๋๋ค.
--resume_from_checkpoint="checkpoint-1500"
ํ์ธํ๋
๋ค์๊ณผ ๊ฐ์ด [Naruto BLIP ์บก์ ](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) ๋ฐ์ดํฐ์ ์์ ํ์ธํ๋ ์คํ์ ์ํด [PyTorch ํ์ต ์คํฌ๋ฆฝํธ](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)๋ฅผ ์คํํฉ๋๋ค:export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/naruto-blip-captions"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-naruto-model"
์์ฒด ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ธํ๋ํ๋ ค๋ฉด ๐ค Datasets์์ ์๊ตฌํ๋ ํ์์ ๋ฐ๋ผ ๋ฐ์ดํฐ์ ์ ์ค๋นํ์ธ์. ๋ฐ์ดํฐ์ ์ ํ๋ธ์ ์ ๋ก๋ํ๊ฑฐ๋ [ํ์ผ๋ค์ด ์๋ ๋ก์ปฌ ํด๋๋ฅผ ์ค๋น](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ํ ์ ์์ต๋๋ค.
์ฌ์ฉ์ ์ปค์คํ
loading logic์ ์ฌ์ฉํ๋ ค๋ฉด ์คํฌ๋ฆฝํธ๋ฅผ ์์ ํ์ญ์์ค. ๋์์ด ๋๋๋ก ์ฝ๋์ ์ ์ ํ ์์น์ ํฌ์ธํฐ๋ฅผ ๋จ๊ฒผ์ต๋๋ค. ๐ค ์๋ ์์ ์คํฌ๋ฆฝํธ๋ TRAIN_DIR
์ ๋ก์ปฌ ๋ฐ์ดํฐ์
์ผ๋ก๋ฅผ ํ์ธํ๋ํ๋ ๋ฐฉ๋ฒ๊ณผ OUTPUT_DIR
์์ ๋ชจ๋ธ์ ์ ์ฅํ ์์น๋ฅผ ๋ณด์ฌ์ค๋๋ค:
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
export OUTPUT_DIR="path_to_save_model"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir=${OUTPUT_DIR}
[@duongna211](https://github.com/duongna21)์ ๊ธฐ์ฌ๋ก, Flax๋ฅผ ์ฌ์ฉํด TPU ๋ฐ GPU์์ Stable Diffusion ๋ชจ๋ธ์ ๋ ๋น ๋ฅด๊ฒ ํ์ตํ ์ ์์ต๋๋ค. ์ด๋ TPU ํ๋์จ์ด์์ ๋งค์ฐ ํจ์จ์ ์ด์ง๋ง GPU์์๋ ํ๋ฅญํ๊ฒ ์๋ํฉ๋๋ค. Flax ํ์ต ์คํฌ๋ฆฝํธ๋ gradient checkpointing๋ gradient accumulation๊ณผ ๊ฐ์ ๊ธฐ๋ฅ์ ์์ง ์ง์ํ์ง ์์ผ๋ฏ๋ก ๋ฉ๋ชจ๋ฆฌ๊ฐ 30GB ์ด์์ธ GPU ๋๋ TPU v3๊ฐ ํ์ํฉ๋๋ค.
์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๊ธฐ ์ ์ ์๊ตฌ ์ฌํญ์ด ์ค์น๋์ด ์๋์ง ํ์ธํ์ญ์์ค:
pip install -U -r requirements_flax.txt
๊ทธ๋ฌ๋ฉด ๋ค์๊ณผ ๊ฐ์ด Flax ํ์ต ์คํฌ๋ฆฝํธ๋ฅผ ์คํํ ์ ์์ต๋๋ค.
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export dataset_name="lambdalabs/naruto-blip-captions"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-naruto-model"
์์ฒด ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ธํ๋ํ๋ ค๋ฉด ๐ค Datasets์์ ์๊ตฌํ๋ ํ์์ ๋ฐ๋ผ ๋ฐ์ดํฐ์ ์ ์ค๋นํ์ธ์. ๋ฐ์ดํฐ์ ์ ํ๋ธ์ ์ ๋ก๋ํ๊ฑฐ๋ [ํ์ผ๋ค์ด ์๋ ๋ก์ปฌ ํด๋๋ฅผ ์ค๋น](https ://huggingface.co/docs/datasets/image_dataset#imagefolder)ํ ์ ์์ต๋๋ค.
์ฌ์ฉ์ ์ปค์คํ
loading logic์ ์ฌ์ฉํ๋ ค๋ฉด ์คํฌ๋ฆฝํธ๋ฅผ ์์ ํ์ญ์์ค. ๋์์ด ๋๋๋ก ์ฝ๋์ ์ ์ ํ ์์น์ ํฌ์ธํฐ๋ฅผ ๋จ๊ฒผ์ต๋๋ค. ๐ค ์๋ ์์ ์คํฌ๋ฆฝํธ๋ TRAIN_DIR
์ ๋ก์ปฌ ๋ฐ์ดํฐ์
์ผ๋ก๋ฅผ ํ์ธํ๋ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค:
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-naruto-model"
LoRA
Text-to-image ๋ชจ๋ธ ํ์ธํ๋์ ์ํด, ๋๊ท๋ชจ ๋ชจ๋ธ ํ์ต์ ๊ฐ์ํํ๊ธฐ ์ํ ํ์ธํ๋ ๊ธฐ์ ์ธ LoRA(Low-Rank Adaptation of Large Language Models)๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์์ธํ ๋ด์ฉ์ LoRA ํ์ต ๊ฐ์ด๋๋ฅผ ์ฐธ์กฐํ์ธ์.
์ถ๋ก
ํ๋ธ์ ๋ชจ๋ธ ๊ฒฝ๋ก ๋๋ ๋ชจ๋ธ ์ด๋ฆ์ [StableDiffusionPipeline
]์ ์ ๋ฌํ์ฌ ์ถ๋ก ์ ์ํด ํ์ธ ํ๋๋ ๋ชจ๋ธ์ ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค:
model_path = "path_to_saved_model" pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) pipe.to("cuda")
image = pipe(prompt="yoda").images[0] image.save("yoda-naruto.png")
</pt>
<jax>
```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
model_path = "path_to_saved_model"
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
prompt = "yoda naruto"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
image.save("yoda-naruto.png")