<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Text-to-image
<Tip warning={true}>

The text-to-image fine-tuning script is experimental. It is easy to overfit and run into issues like catastrophic forgetting. We recommend exploring different hyperparameters to get the best results on your own dataset.

</Tip>
Text-to-image models like Stable Diffusion generate an image from a text prompt. This guide shows how to fine-tune the [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) model on your own dataset with PyTorch and Flax. If you are interested in taking a closer look at all the training scripts for text-to-image fine-tuning used in this guide, you can find them in this [repository](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image).
Before running the scripts, make sure to install the library's training dependencies:

```bash
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
```
Then initialize a [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment:

```bash
accelerate config
```
If you have already cloned the repository, you don't need to go through these steps. Instead, you can pass the path to your local checkout to the training script, and it will be loaded from there.
### Hardware requirements

With `gradient_checkpointing` and `mixed_precision`, it should be possible to fine-tune the model on a single 24GB GPU. For a higher `batch_size` and faster training, it's better to use a GPU with more than 30GB of memory. You can also use JAX or Flax for fine-tuning on TPUs or GPUs; see [below](#flax-jax-finetuning) for details.

You can reduce memory usage significantly by enabling memory efficient attention with xFormers. Make sure [xFormers is installed](./optimization/xformers) and pass the `--enable_xformers_memory_efficient_attention` flag to the training script.

xFormers is not available for Flax.
## Upload the model to the Hub

Store your model on the Hub by adding the following argument to the training script:

```bash
--push_to_hub
```
## Save and load checkpoints

It's a good idea to save checkpoints regularly in case anything happens during training. To save a checkpoint, pass the following argument to the training script:

```bash
--checkpointing_steps=500
```
Every 500 steps, the full training state is saved in a subfolder of `output_dir`. The checkpoint name has the format `checkpoint-` followed by the number of steps trained so far. For example, `checkpoint-1500` is a checkpoint saved after 1500 training steps.

To load a checkpoint and resume training, pass the `--resume_from_checkpoint` argument to the training script and specify the checkpoint you want to resume from. For example, the following argument resumes training from the checkpoint saved after 1500 training steps:
```bash
--resume_from_checkpoint="checkpoint-1500"
```
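Because each checkpoint folder encodes its step count in its name, picking the most recent one can be automated with a small helper. The function below is a hypothetical sketch (not part of the training script) that parses the `checkpoint-<step>` naming scheme; recent versions of the script can also do this for you when passed `--resume_from_checkpoint="latest"`:

```python
import os
import re
import tempfile


def latest_checkpoint(output_dir):
    """Return the name of the `checkpoint-<step>` subfolder with the highest step."""
    ckpts = [d for d in os.listdir(output_dir) if re.fullmatch(r"checkpoint-\d+", d)]
    # Sort numerically on the step suffix, not lexicographically.
    return max(ckpts, key=lambda name: int(name.rsplit("-", 1)[1]), default=None)


# Demo against a mock output_dir layout.
with tempfile.TemporaryDirectory() as output_dir:
    for step in (500, 1000, 1500):
        os.makedirs(os.path.join(output_dir, f"checkpoint-{step}"))
    print(latest_checkpoint(output_dir))  # checkpoint-1500
```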
## Fine-tuning

<frameworkcontent>
<pt>
Launch the [PyTorch training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) for a fine-tuning run on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset like this:
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/naruto-blip-captions"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-naruto-model"
```
To fine-tune on your own dataset, prepare the dataset in the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub) or [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
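The ImageFolder layout that 🤗 Datasets expects can be sketched with the standard library alone: a folder of images plus a `metadata.jsonl` file whose `file_name` column maps each image to its caption. The file names and captions below are purely illustrative:

```python
import json
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    train_dir = os.path.join(root, "train")
    os.makedirs(train_dir)

    # One JSON object per line; `file_name` is required by ImageFolder,
    # and the caption column (here `text`) holds the prompt for each image.
    captions = [
        {"file_name": "0001.png", "text": "a character with orange hair"},
        {"file_name": "0002.png", "text": "a character wearing a headband"},
    ]
    with open(os.path.join(train_dir, "metadata.jsonl"), "w") as f:
        for row in captions:
            f.write(json.dumps(row) + "\n")

    # The .png files referenced above would live alongside metadata.jsonl.
    with open(os.path.join(train_dir, "metadata.jsonl")) as f:
        rows = [json.loads(line) for line in f]
    print(len(rows), rows[0]["file_name"])  # 2 0001.png
```

The training script reads the caption column via its `--caption_column` argument, which defaults to `text`.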
If you need custom loading logic, modify the script; we have left pointers in the appropriate places in the code to help you. The example script below shows how to fine-tune on a local dataset in `TRAIN_DIR` and where to save the model with `OUTPUT_DIR`:
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
export OUTPUT_DIR="path_to_save_model"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir=${OUTPUT_DIR}
```
</pt>
<jax>
Thanks to a contribution from [@duongna211](https://github.com/duongna21), you can train Stable Diffusion models faster on TPUs and GPUs with Flax. This is very efficient on TPU hardware, but it works great on GPUs too. The Flax training script doesn't yet support features like gradient checkpointing or gradient accumulation, so you'll need a GPU with at least 30GB of memory or a TPU v3.

Before running the script, make sure the requirements are installed:
```bash
pip install -U -r requirements_flax.txt
```
Then you can launch the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) like this:
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export dataset_name="lambdalabs/naruto-blip-captions"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-naruto-model"
```
To fine-tune on your own dataset, prepare the dataset in the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub) or [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder).

If you need custom loading logic, modify the script; we have left pointers in the appropriate places in the code to help you. The example script below shows how to fine-tune on a local dataset in `TRAIN_DIR`:
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-naruto-model"
```
</jax>
</frameworkcontent>
## LoRA

You can also fine-tune a text-to-image model with LoRA (Low-Rank Adaptation of Large Language Models), a fine-tuning technique for accelerating the training of large models. See the [LoRA training](lora#text-to-image) guide for more details.
## Inference

You can load the fine-tuned model for inference by passing the model path, or the model name on the Hub, to [`StableDiffusionPipeline`]:

<frameworkcontent>
<pt>
```python
import torch
from diffusers import StableDiffusionPipeline

model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-naruto.png")
```
</pt>
<jax>
```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

model_path = "path_to_saved_model"
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)

prompt = "yoda naruto"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
images[0].save("yoda-naruto.png")
```
</jax>
</frameworkcontent>