Spaces:
Running
on
Zero
A newer version of the Gradio SDK is available:
5.9.1
ํ์ดํ๋ผ์ธ, ๋ชจ๋ธ, ์ค์ผ์ค๋ฌ ๋ถ๋ฌ์ค๊ธฐ
๊ธฐ๋ณธ์ ์ผ๋ก diffusion ๋ชจ๋ธ์ ๋ค์ํ ์ปดํฌ๋ํธ๋ค(๋ชจ๋ธ, ํ ํฌ๋์ด์ , ์ค์ผ์ค๋ฌ) ๊ฐ์ ๋ณต์กํ ์ํธ์์ฉ์ ๊ธฐ๋ฐ์ผ๋ก ๋์ํฉ๋๋ค. ๋ํจ์ ์ค(Diffusers)๋ ์ด๋ฌํ diffusion ๋ชจ๋ธ์ ๋ณด๋ค ์ฝ๊ณ ๊ฐํธํ API๋ก ์ ๊ณตํ๋ ๊ฒ์ ๋ชฉํ๋ก ์ค๊ณ๋์์ต๋๋ค. [DiffusionPipeline
]์ diffusion ๋ชจ๋ธ์ด ๊ฐ๋ ๋ณต์ก์ฑ์ ํ๋์ ํ์ดํ๋ผ์ธ API๋ก ํตํฉํ๊ณ , ๋์์ ์ด๋ฅผ ๊ตฌ์ฑํ๋ ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ค์ ํ์คํฌ์ ๋ง์ถฐ ์ ์ฐํ๊ฒ ์ปค์คํฐ๋ง์ด์งํ ์ ์๋๋ก ์ง์ํ๊ณ ์์ต๋๋ค.
diffusion ๋ชจ๋ธ์ ํ๋ จ๊ณผ ์ถ๋ก ์ ํ์ํ ๋ชจ๋ ๊ฒ์ [DiffusionPipeline.from_pretrained
] ๋ฉ์๋๋ฅผ ํตํด ์ ๊ทผํ ์ ์์ต๋๋ค. (์ด ๋ง์ ์๋ฏธ๋ ๋ค์ ๋จ๋ฝ์์ ๋ณด๋ค ์์ธํ๊ฒ ๋ค๋ค๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.)
์ด ๋ฌธ์์์๋ ์ค๋ช ํ ๋ด์ฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
ํ๋ธ๋ฅผ ํตํด ํน์ ๋ก์ปฌ๋ก ํ์ดํ๋ผ์ธ์ ๋ถ๋ฌ์ค๋ ๋ฒ
ํ์ดํ๋ผ์ธ์ ๋ค๋ฅธ ์ปดํฌ๋ํธ๋ค์ ์ ์ฉํ๋ ๋ฒ
์ค๋ฆฌ์ง๋ ์ฒดํฌํฌ์ธํธ๊ฐ ์๋ variant๋ฅผ ๋ถ๋ฌ์ค๋ ๋ฒ (variant๋ ๊ธฐ๋ณธ์ผ๋ก ์ค์ ๋
fp32
๊ฐ ์๋ ๋ค๋ฅธ ๋ถ๋ ์์์ ํ์ (์:fp16
)์ ์ฌ์ฉํ๊ฑฐ๋ Non-EMA ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ ์ฒดํฌํฌ์ธํธ๋ค์ ์๋ฏธํฉ๋๋ค.)๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๋ถ๋ฌ์ค๋ ๋ฒ
Diffusion ํ์ดํ๋ผ์ธ
๐ก [DiffusionPipeline
] ํด๋์ค๊ฐ ๋์ํ๋ ๋ฐฉ์์ ๋ณด๋ค ์์ธํ ๋ด์ฉ์ด ๊ถ๊ธํ๋ค๋ฉด, DiffusionPipeline explained ์น์
์ ํ์ธํด๋ณด์ธ์.
[DiffusionPipeline
] ํด๋์ค๋ diffusion ๋ชจ๋ธ์ ํ๋ธ๋ก๋ถํฐ ๋ถ๋ฌ์ค๋ ๊ฐ์ฅ ์ฌํํ๋ฉด์ ๋ณดํธ์ ์ธ ๋ฐฉ์์
๋๋ค. [DiffusionPipeline.from_pretrained
] ๋ฉ์๋๋ ์ ํฉํ ํ์ดํ๋ผ์ธ ํด๋์ค๋ฅผ ์๋์ผ๋ก ํ์งํ๊ณ , ํ์ํ ๊ตฌ์ฑ์์(configuration)์ ๊ฐ์ค์น(weight) ํ์ผ๋ค์ ๋ค์ด๋ก๋ํ๊ณ ์บ์ฑํ ๋ค์, ํด๋น ํ์ดํ๋ผ์ธ ์ธ์คํด์ค๋ฅผ ๋ฐํํฉ๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(repo_id)
๋ฌผ๋ก [DiffusionPipeline
] ํด๋์ค๋ฅผ ์ฌ์ฉํ์ง ์๊ณ , ๋ช
์์ ์ผ๋ก ์ง์ ํด๋น ํ์ดํ๋ผ์ธ ํด๋์ค๋ฅผ ๋ถ๋ฌ์ค๋ ๊ฒ๋ ๊ฐ๋ฅํฉ๋๋ค. ์๋ ์์ ์ฝ๋๋ ์ ์์์ ๋์ผํ ์ธ์คํด์ค๋ฅผ ๋ฐํํฉ๋๋ค.
from diffusers import StableDiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(repo_id)
CompVis/stable-diffusion-v1-4์ด๋ runwayml/stable-diffusion-v1-5 ๊ฐ์ ์ฒดํฌํฌ์ธํธ๋ค์ ๊ฒฝ์ฐ, ํ๋ ์ด์์ ๋ค์ํ ํ์คํฌ์ ํ์ฉ๋ ์ ์์ต๋๋ค. (์๋ฅผ ๋ค์ด ์์ ๋ ์ฒดํฌํฌ์ธํธ์ ๊ฒฝ์ฐ, text-to-image์ image-to-image์ ๋ชจ๋ ํ์ฉ๋ ์ ์์ต๋๋ค.) ๋ง์ฝ ์ด๋ฌํ ์ฒดํฌํฌ์ธํธ๋ค์ ๊ธฐ๋ณธ ์ค์ ํ์คํฌ๊ฐ ์๋ ๋ค๋ฅธ ํ์คํฌ์ ํ์ฉํ๊ณ ์ ํ๋ค๋ฉด, ํด๋น ํ์คํฌ์ ๋์๋๋ ํ์ดํ๋ผ์ธ(task-specific pipeline)์ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
from diffusers import StableDiffusionImg2ImgPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id)
๋ก์ปฌ ํ์ดํ๋ผ์ธ
ํ์ดํ๋ผ์ธ์ ๋ก์ปฌ๋ก ๋ถ๋ฌ์ค๊ณ ์ ํ๋ค๋ฉด, git-lfs
๋ฅผ ์ฌ์ฉํ์ฌ ์ง์ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ก์ปฌ ๋์คํฌ์ ๋ค์ด๋ก๋ ๋ฐ์์ผ ํฉ๋๋ค. ์๋์ ๋ช
๋ น์ด๋ฅผ ์คํํ๋ฉด ./stable-diffusion-v1-5
๋ ์ด๋ฆ์ผ๋ก ํด๋๊ฐ ๋ก์ปฌ๋์คํฌ์ ์์ฑ๋ฉ๋๋ค.
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
๊ทธ๋ฐ ๋ค์ ํด๋น ๋ก์ปฌ ๊ฒฝ๋ก๋ฅผ [~DiffusionPipeline.from_pretrained
] ๋ฉ์๋์ ์ ๋ฌํฉ๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "./stable-diffusion-v1-5"
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id)
์์ ์์์ฝ๋์ฒ๋ผ ๋ง์ฝ repo_id
๊ฐ ๋ก์ปฌ ํจ์ค(local path)๋ผ๋ฉด, [~DiffusionPipeline.from_pretrained
] ๋ฉ์๋๋ ์ด๋ฅผ ์๋์ผ๋ก ๊ฐ์งํ์ฌ ํ๋ธ์์ ํ์ผ์ ๋ค์ด๋ก๋ํ์ง ์์ต๋๋ค. ๋ง์ฝ ๋ก์ปฌ ๋์คํฌ์ ์ ์ฅ๋ ํ์ดํ๋ผ์ธ ์ฒดํฌํฌ์ธํธ๊ฐ ์ต์ ๋ฒ์ ์ด ์๋ ๊ฒฝ์ฐ์๋, ์ต์ ๋ฒ์ ์ ๋ค์ด๋ก๋ํ์ง ์๊ณ ๊ธฐ์กด ๋ก์ปฌ ๋์คํฌ์ ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
ํ์ดํ๋ผ์ธ ๋ด๋ถ์ ์ปดํฌ๋ํธ ๊ต์ฒดํ๊ธฐ
ํ์ดํ๋ผ์ธ ๋ด๋ถ์ ์ปดํฌ๋ํธ๋ค์ ํธํ ๊ฐ๋ฅํ ๋ค๋ฅธ ์ปดํฌ๋ํธ๋ก ๊ต์ฒด๋ ์ ์์ต๋๋ค. ์ด์ ๊ฐ์ ์ปดํฌ๋ํธ ๊ต์ฒด๊ฐ ์ค์ํ ์ด์ ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ์ด๋ค ์ค์ผ์ค๋ฌ๋ฅผ ์ฌ์ฉํ ๊ฒ์ธ๊ฐ๋ ์์ฑ์๋์ ์์ฑํ์ง ๊ฐ์ ํธ๋ ์ด๋์คํ๋ฅผ ์ ์ํ๋ ์ค์ํ ์์์ ๋๋ค.
- diffusion ๋ชจ๋ธ ๋ด๋ถ์ ์ปดํฌ๋ํธ๋ค์ ์ผ๋ฐ์ ์ผ๋ก ๊ฐ๊ฐ ๋ ๋ฆฝ์ ์ผ๋ก ํ๋ จ๋๊ธฐ ๋๋ฌธ์, ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ ์ปดํฌ๋ํธ๊ฐ ์๋ค๋ฉด ๊ทธ๊ฑธ๋ก ๊ต์ฒดํ๋ ์์ผ๋ก ์ฑ๋ฅ์ ํฅ์์ํฌ ์ ์์ต๋๋ค.
- ํ์ธ ํ๋ ๋จ๊ณ์์๋ ์ผ๋ฐ์ ์ผ๋ก UNet ํน์ ํ ์คํธ ์ธ์ฝ๋์ ๊ฐ์ ์ผ๋ถ ์ปดํฌ๋ํธ๋ค๋ง ํ๋ จํ๊ฒ ๋ฉ๋๋ค.
์ด๋ค ์ค์ผ์ค๋ฌ๋ค์ด ํธํ๊ฐ๋ฅํ์ง๋ compatibles
์์ฑ์ ํตํด ํ์ธํ ์ ์์ต๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id)
stable_diffusion.scheduler.compatibles
์ด๋ฒ์๋ [SchedulerMixin.from_pretrained
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํด์, ๊ธฐ์กด ๊ธฐ๋ณธ ์ค์ผ์ค๋ฌ์๋ [PNDMScheduler
]๋ฅผ ๋ณด๋ค ์ฐ์ํ ์ฑ๋ฅ์ [EulerDiscreteScheduler
]๋ก ๋ฐ๊ฟ๋ด
์๋ค. ์ค์ผ์ค๋ฌ๋ฅผ ๋ก๋ํ ๋๋ subfolder
์ธ์๋ฅผ ํตํด, ํด๋น ํ์ดํ๋ผ์ธ์ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ์ค์ผ์ค๋ฌ์ ๊ดํ ํ์ํด๋๋ฅผ ๋ช
์ํด์ฃผ์ด์ผ ํฉ๋๋ค.
๊ทธ ๋ค์ ์๋กญ๊ฒ ์์ฑํ [EulerDiscreteScheduler
] ์ธ์คํด์ค๋ฅผ [DiffusionPipeline
]์ scheduler
์ธ์์ ์ ๋ฌํฉ๋๋ค.
from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler
repo_id = "runwayml/stable-diffusion-v1-5"
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
์ธ์ดํํฐ ์ฒด์ปค
์คํ
์ด๋ธ diffusion๊ณผ ๊ฐ์ diffusion ๋ชจ๋ธ๋ค์ ์ ํดํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์๋ ์์ต๋๋ค. ์ด๋ฅผ ์๋ฐฉํ๊ธฐ ์ํด ๋ํจ์ ์ค๋ ์์ฑ๋ ์ด๋ฏธ์ง์ ์ ํด์ฑ์ ํ๋จํ๋ ์ธ์ดํํฐ ์ฒด์ปค(safety checker) ๊ธฐ๋ฅ์ ์ง์ํ๊ณ ์์ต๋๋ค. ๋ง์ฝ ์ธ์ดํํฐ ์ฒด์ปค์ ์ฌ์ฉ์ ์ํ์ง ์๋๋ค๋ฉด, safety_checker
์ธ์์ None
์ ์ ๋ฌํด์ฃผ์๋ฉด ๋ฉ๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None)
์ปดํฌ๋ํธ ์ฌ์ฌ์ฉ
๋ณต์์ ํ์ดํ๋ผ์ธ์ ๋์ผํ ๋ชจ๋ธ์ด ๋ฐ๋ณต์ ์ผ๋ก ์ฌ์ฉํ๋ค๋ฉด, ๊ตณ์ด ํด๋น ๋ชจ๋ธ์ ๋์ผํ ๊ฐ์ค์น๋ฅผ ์ค๋ณต์ผ๋ก RAM์ ๋ถ๋ฌ์ฌ ํ์๋ ์์ ๊ฒ์
๋๋ค. [~DiffusionPipeline.components
] ์์ฑ์ ํตํด ํ์ดํ๋ผ์ธ ๋ด๋ถ์ ์ปดํฌ๋ํธ๋ค์ ์ฐธ์กฐํ ์ ์๋๋ฐ, ์ด๋ฒ ๋จ๋ฝ์์๋ ์ด๋ฅผ ํตํด ๋์ผํ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ RAM์ ์ค๋ณต์ผ๋ก ๋ถ๋ฌ์ค๋ ๊ฒ์ ๋ฐฉ์งํ๋ ๋ฒ์ ๋ํด ์์๋ณด๊ฒ ์ต๋๋ค.
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
model_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)
components = stable_diffusion_txt2img.components
๊ทธ ๋ค์ ์ ์์ ์ฝ๋์์ ์ ์ธํ components
๋ณ์๋ฅผ ๋ค๋ฅธ ํ์ดํ๋ผ์ธ์ ์ ๋ฌํจ์ผ๋ก์จ, ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ์ค๋ณต์ผ๋ก RAM์ ๋ก๋ฉํ์ง ์๊ณ , ๋์ผํ ์ปดํฌ๋ํธ๋ฅผ ์ฌ์ฌ์ฉํ ์ ์์ต๋๋ค.
stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)
๋ฌผ๋ก ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ค์ ๋ฐ๋ก ๋ฐ๋ก ํ์ดํ๋ผ์ธ์ ์ ๋ฌํ ์๋ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด stable_diffusion_txt2img
ํ์ดํ๋ผ์ธ ์์ ์ปดํฌ๋ํธ๋ค ๊ฐ์ด๋ฐ์ ์ธ์ดํํฐ ์ฒด์ปค(safety_checker
)์ ํผ์ณ ์ต์คํธ๋ํฐ(feature_extractor
)๋ฅผ ์ ์ธํ ์ปดํฌ๋ํธ๋ค๋ง stable_diffusion_img2img
ํ์ดํ๋ผ์ธ์์ ์ฌ์ฌ์ฉํ๋ ๋ฐฉ์ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค.
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
model_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)
stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
vae=stable_diffusion_txt2img.vae,
text_encoder=stable_diffusion_txt2img.text_encoder,
tokenizer=stable_diffusion_txt2img.tokenizer,
unet=stable_diffusion_txt2img.unet,
scheduler=stable_diffusion_txt2img.scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
Checkpoint variants
Variant๋ ์ผ๋ฐ์ ์ผ๋ก ๋ค์๊ณผ ๊ฐ์ ์ฒดํฌํฌ์ธํธ๋ค์ ์๋ฏธํฉ๋๋ค.
torch.float16
๊ณผ ๊ฐ์ด ์ ๋ฐ๋๋ ๋ ๋ฎ์ง๋ง, ์ฉ๋ ์ญ์ ๋ ์์ ๋ถ๋์์์ ํ์ ์ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ ์ฒดํฌํฌ์ธํธ. (๋ค๋ง ์ด์ ๊ฐ์ variant์ ๊ฒฝ์ฐ, ์ถ๊ฐ์ ์ธ ํ๋ จ๊ณผ CPUํ๊ฒฝ์์์ ๊ตฌ๋์ด ๋ถ๊ฐ๋ฅํฉ๋๋ค.)- Non-EMA ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ ์ฒดํฌํฌ์ธํธ. (Non-EMA ๊ฐ์ค์น์ ๊ฒฝ์ฐ, ํ์ธ ํ๋ ๋จ๊ณ์์ ์ฌ์ฉํ๋ ๊ฒ์ด ๊ถ์ฅ๋๋๋ฐ, ์ถ๋ก ๋จ๊ณ์์ ์ฌ์ฉํ์ง ์๋ ๊ฒ์ด ๊ถ์ฅ๋ฉ๋๋ค.)
๐ก ๋ชจ๋ธ ๊ตฌ์กฐ๋ ๋์ผํ์ง๋ง ์๋ก ๋ค๋ฅธ ํ์ต ํ๊ฒฝ์์ ์๋ก ๋ค๋ฅธ ๋ฐ์ดํฐ์
์ผ๋ก ํ์ต๋ ์ฒดํฌํฌ์ธํธ๋ค์ด ์์ ๊ฒฝ์ฐ, ํด๋น ์ฒดํฌํฌ์ธํธ๋ค์ variant ๋จ๊ณ๊ฐ ์๋ ๋ฆฌํฌ์งํ ๋ฆฌ ๋จ๊ณ์์ ๋ถ๋ฆฌ๋์ด ๊ด๋ฆฌ๋์ด์ผ ํฉ๋๋ค. (์ฆ, ํด๋น ์ฒดํฌํฌ์ธํธ๋ค์ ์๋ก ๋ค๋ฅธ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ๋ฐ๋ก ๊ด๋ฆฌ๋์ด์ผ ํฉ๋๋ค. ์์: [stable-diffusion-v1-4
], [stable-diffusion-v1-5
]).
checkpoint type | weight name | argument for loading weights |
---|---|---|
original | diffusion_pytorch_model.bin | |
floating point | diffusion_pytorch_model.fp16.bin | variant , torch_dtype |
non-EMA | diffusion_pytorch_model.non_ema.bin | variant |
variant๋ฅผ ๋ก๋ํ ๋ 2๊ฐ์ ์ค์ํ argument๊ฐ ์์ต๋๋ค.
torch_dtype
์ ๋ถ๋ฌ์ฌ ์ฒดํฌํฌ์ธํธ์ ๋ถ๋์์์ ์ ์ ์ํฉ๋๋ค. ์๋ฅผ ๋ค์ดtorch_dtype=torch.float16
์ ๋ช ์ํจ์ผ๋ก์จ ๊ฐ์ค์น์ ๋ถ๋์์์ ํ์ ์fl16
์ผ๋ก ๋ณํํ ์ ์์ต๋๋ค. (๋ง์ฝ ๋ฐ๋ก ์ค์ ํ์ง ์์ ๊ฒฝ์ฐ, ๊ธฐ๋ณธ๊ฐ์ผ๋กfp32
ํ์ ์ ๊ฐ์ค์น๊ฐ ๋ก๋ฉ๋ฉ๋๋ค.) ๋ํvariant
์ธ์๋ฅผ ๋ช ์ํ์ง ์์ ์ฑ๋ก ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ถ๋ฌ์จ ๋ค์, ํด๋น ์ฒดํฌํฌ์ธํธ๋ฅผtorch_dtype=torch.float16
์ธ์๋ฅผ ํตํดfp16
ํ์ ์ผ๋ก ๋ณํํ๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค. ์ด ๊ฒฝ์ฐ ๊ธฐ๋ณธ์ผ๋ก ์ค์ ๋fp32
๊ฐ์ค์น๊ฐ ๋จผ์ ๋ค์ด๋ก๋๋๊ณ , ํด๋น ๊ฐ์ค์น๋ค์ ๋ถ๋ฌ์จ ๋ค์fp16
ํ์ ์ผ๋ก ๋ณํํ๊ฒ ๋ฉ๋๋ค.variant
์ธ์๋ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ์ด๋ค variant๋ฅผ ๋ถ๋ฌ์ฌ ๊ฒ์ธ๊ฐ๋ฅผ ์ ์ํฉ๋๋ค. ๊ฐ๋ นdiffusers/stable-diffusion-variants
๋ฆฌํฌ์งํ ๋ฆฌ๋ก๋ถํฐnon_ema
์ฒดํฌํฌ์ธํธ๋ฅผ ๋ถ๋ฌ์ค๊ณ ์ ํ๋ค๋ฉด,variant="non_ema"
์ธ์๋ฅผ ์ ๋ฌํด์ผ ํฉ๋๋ค.
from diffusers import DiffusionPipeline
# load fp16 variant
stable_diffusion = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
)
# load non_ema variant
stable_diffusion = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", variant="non_ema")
๋ค๋ฅธ ๋ถ๋์์์ ํ์
์ ๊ฐ์ค์น ํน์ non-EMA ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ๊ธฐ ์ํด์๋, [DiffusionPipeline.save_pretrained
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํด์ผ ํ๋ฉฐ, ์ด ๋ variant
์ธ์๋ฅผ ๋ช
์ํด์ค์ผ ํฉ๋๋ค. ์๋์ ์ฒดํฌํฌ์ธํธ์ ๋์ผํ ํด๋์ variant๋ฅผ ์ ์ฅํด์ผ ํ๋ฉฐ, ์ด๋ ๊ฒ ํ๋ฉด ๋์ผํ ํด๋์์ ์ค๋ฆฌ์ง๋ ์ฒดํฌํฌ์ธํธ๊ณผ variant๋ฅผ ๋ชจ๋ ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค.
from diffusers import DiffusionPipeline
# save as fp16 variant
stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="fp16")
# save as non-ema variant
stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="non_ema")
๋ง์ฝ variant๋ฅผ ๊ธฐ์กด ํด๋์ ์ ์ฅํ์ง ์์ ๊ฒฝ์ฐ, variant
์ธ์๋ฅผ ๋ฐ๋์ ๋ช
์ํด์ผ ํฉ๋๋ค. ๊ทธ๋ ๊ฒ ํ์ง ์์ ๊ฒฝ์ฐ ์๋์ ์ค๋ฆฌ์ง๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฐพ์ ์ ์๊ฒ ๋๊ธฐ ๋๋ฌธ์ ์๋ฌ๊ฐ ๋ฐ์ํฉ๋๋ค.
# ๐ this won't work
stable_diffusion = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", torch_dtype=torch.float16)
# ๐ this works
stable_diffusion = DiffusionPipeline.from_pretrained(
"./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
)
๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
๋ชจ๋ธ๋ค์ [ModelMixin.from_pretrained
] ๋ฉ์๋๋ฅผ ํตํด ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค. ํด๋น ๋ฉ์๋๋ ์ต์ ๋ฒ์ ์ ๋ชจ๋ธ ๊ฐ์ค์น ํ์ผ๊ณผ ์ค์ ํ์ผ(configurations)์ ๋ค์ด๋ก๋ํ๊ณ ์บ์ฑํฉ๋๋ค. ๋ง์ฝ ์ด๋ฌํ ํ์ผ๋ค์ด ์ต์ ๋ฒ์ ์ผ๋ก ๋ก์ปฌ ์บ์์ ์ ์ฅ๋์ด ์๋ค๋ฉด, [ModelMixin.from_pretrained
]๋ ๊ตณ์ด ํด๋น ํ์ผ๋ค์ ๋ค์ ๋ค์ด๋ก๋ํ์ง ์์ผ๋ฉฐ, ๊ทธ์ ์บ์์ ์๋ ์ต์ ํ์ผ๋ค์ ์ฌ์ฌ์ฉํฉ๋๋ค.
๋ชจ๋ธ์ subfolder
์ธ์์ ๋ช
์๋ ํ์ ํด๋๋ก๋ถํฐ ๋ก๋๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด runwayml/stable-diffusion-v1-5
์ UNet ๋ชจ๋ธ์ ๊ฐ์ค์น๋ unet
ํด๋์ ์ ์ฅ๋์ด ์์ต๋๋ค.
from diffusers import UNet2DConditionModel
repo_id = "runwayml/stable-diffusion-v1-5"
model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")
ํน์ ํด๋น ๋ชจ๋ธ์ ๋ฆฌํฌ์งํ ๋ฆฌ๋ก๋ถํฐ ๋ค์ด๋ ํธ๋ก ๊ฐ์ ธ์ค๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค.
from diffusers import UNet2DModel
repo_id = "google/ddpm-cifar10-32"
model = UNet2DModel.from_pretrained(repo_id)
๋ํ ์์ ๋ดค๋ variant
์ธ์๋ฅผ ๋ช
์ํจ์ผ๋ก์จ, Non-EMA๋ fp16
์ ๊ฐ์ค์น๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค.
from diffusers import UNet2DConditionModel
model = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet", variant="non-ema")
model.save_pretrained("./local-unet", variant="non-ema")
์ค์ผ์ค๋ฌ
์ค์ผ์ค๋ฌ๋ค์ [SchedulerMixin.from_pretrained
] ๋ฉ์๋๋ฅผ ํตํด ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค. ๋ชจ๋ธ๊ณผ ๋ฌ๋ฆฌ ์ค์ผ์ค๋ฌ๋ ๋ณ๋์ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ์์ผ๋ฉฐ, ๋ฐ๋ผ์ ๋น์ฐํ ๋ณ๋์ ํ์ต๊ณผ์ ์ ์๊ตฌํ์ง ์์ต๋๋ค. ์ด๋ฌํ ์ค์ผ์ค๋ฌ๋ค์ (ํด๋น ์ค์ผ์ค๋ฌ ํ์ํด๋์) configration ํ์ผ์ ํตํด ์ ์๋ฉ๋๋ค.
์ฌ๋ฌ๊ฐ์ ์ค์ผ์ค๋ฌ๋ฅผ ๋ถ๋ฌ์จ๋ค๊ณ ํด์ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์๋ชจํ๋ ๊ฒ์ ์๋๋ฉฐ, ๋ค์ํ ์ค์ผ์ค๋ฌ๋ค์ ๋์ผํ ์ค์ผ์ค๋ฌ configration์ ์ ์ฉํ๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค. ๋ค์ ์์ ์ฝ๋์์ ๋ถ๋ฌ์ค๋ ์ค์ผ์ค๋ฌ๋ค์ ๋ชจ๋ [StableDiffusionPipeline
]๊ณผ ํธํ๋๋๋ฐ, ์ด๋ ๊ณง ํด๋น ์ค์ผ์ค๋ฌ๋ค์ ๋์ผํ ์ค์ผ์ค๋ฌ configration ํ์ผ์ ์ ์ฉํ ์ ์์์ ์๋ฏธํฉ๋๋ค.
from diffusers import StableDiffusionPipeline
from diffusers import (
DDPMScheduler,
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
repo_id = "runwayml/stable-diffusion-v1-5"
ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler_anc`, `euler`
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
DiffusionPipeline์ ๋ํด ์์๋ณด๊ธฐ
ํด๋์ค ๋ฉ์๋๋ก์ [DiffusionPipeline.from_pretrained
]์ 2๊ฐ์ง๋ฅผ ๋ด๋นํฉ๋๋ค.
- ์ฒซ์งธ๋ก,
from_pretrained
๋ฉ์๋๋ ์ต์ ๋ฒ์ ์ ํ์ดํ๋ผ์ธ์ ๋ค์ด๋ก๋ํ๊ณ , ์บ์์ ์ ์ฅํฉ๋๋ค. ์ด๋ฏธ ๋ก์ปฌ ์บ์์ ์ต์ ๋ฒ์ ์ ํ์ดํ๋ผ์ธ์ด ์ ์ฅ๋์ด ์๋ค๋ฉด, [DiffusionPipeline.from_pretrained
]์ ํด๋น ํ์ผ๋ค์ ๋ค์ ๋ค์ด๋ก๋ํ์ง ์๊ณ , ๋ก์ปฌ ์บ์์ ์ ์ฅ๋์ด ์๋ ํ์ดํ๋ผ์ธ์ ๋ถ๋ฌ์ต๋๋ค. model_index.json
ํ์ผ์ ํตํด ์ฒดํฌํฌ์ธํธ์ ๋์๋๋ ์ ํฉํ ํ์ดํ๋ผ์ธ ํด๋์ค๋ก ๋ถ๋ฌ์ต๋๋ค.
ํ์ดํ๋ผ์ธ์ ํด๋ ๊ตฌ์กฐ๋ ํด๋น ํ์ดํ๋ผ์ธ ํด๋์ค์ ๊ตฌ์กฐ์ ์ง์ ์ ์ผ๋ก ์ผ์นํฉ๋๋ค. ์๋ฅผ ๋ค์ด [StableDiffusionPipeline
] ํด๋์ค๋ runwayml/stable-diffusion-v1-5
๋ฆฌํฌ์งํ ๋ฆฌ์ ๋์๋๋ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ต๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipeline = DiffusionPipeline.from_pretrained(repo_id)
print(pipeline)
์์ ์ฝ๋ ์ถ๋ ฅ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํด๋ณด๋ฉด, pipeline
์ [StableDiffusionPipeline
]์ ์ธ์คํด์ค์ด๋ฉฐ, ๋ค์๊ณผ ๊ฐ์ด ์ด 7๊ฐ์ ์ปดํฌ๋ํธ๋ก ๊ตฌ์ฑ๋๋ค๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
"feature_extractor"
: [~transformers.CLIPFeatureExtractor
]์ ์ธ์คํด์ค"safety_checker"
: ์ ํดํ ์ปจํ ์ธ ๋ฅผ ์คํฌ๋ฆฌ๋ํ๊ธฐ ์ํ ์ปดํฌ๋ํธ"scheduler"
: [PNDMScheduler
]์ ์ธ์คํด์ค"text_encoder"
: [~transformers.CLIPTextModel
]์ ์ธ์คํด์ค"tokenizer"
: a [~transformers.CLIPTokenizer
]์ ์ธ์คํด์ค"unet"
: [UNet2DConditionModel
]์ ์ธ์คํด์ค"vae"
[AutoencoderKL
]์ ์ธ์คํด์ค
StableDiffusionPipeline {
"feature_extractor": [
"transformers",
"CLIPImageProcessor"
],
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
],
"scheduler": [
"diffusers",
"PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"tokenizer": [
"transformers",
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}
ํ์ดํ๋ผ์ธ ์ธ์คํด์ค์ ์ปดํฌ๋ํธ๋ค์ runwayml/stable-diffusion-v1-5
์ ํด๋ ๊ตฌ์กฐ์ ๋น๊ตํด๋ณผ ๊ฒฝ์ฐ, ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ง๋ค ๋ณ๋์ ํด๋๊ฐ ์์์ ํ์ธํ ์ ์์ต๋๋ค.
.
โโโ feature_extractor
โ โโโ preprocessor_config.json
โโโ model_index.json
โโโ safety_checker
โ โโโ config.json
โ โโโ pytorch_model.bin
โโโ scheduler
โ โโโ scheduler_config.json
โโโ text_encoder
โ โโโ config.json
โ โโโ pytorch_model.bin
โโโ tokenizer
โ โโโ merges.txt
โ โโโ special_tokens_map.json
โ โโโ tokenizer_config.json
โ โโโ vocab.json
โโโ unet
โ โโโ config.json
โ โโโ diffusion_pytorch_model.bin
โโโ vae
โโโ config.json
โโโ diffusion_pytorch_model.bin
๋ํ ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ค์ ํ์ดํ๋ผ์ธ ์ธ์คํด์ค์ ์์ฑ์ผ๋ก์จ ์ฐธ์กฐํ ์ ์์ต๋๋ค.
pipeline.tokenizer
CLIPTokenizer(
name_or_path="/root/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/tokenizer",
vocab_size=49408,
model_max_length=77,
is_fast=False,
padding_side="right",
truncation_side="right",
special_tokens={
"bos_token": AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
"eos_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
"unk_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
"pad_token": "<|endoftext|>",
},
)
๋ชจ๋ ํ์ดํ๋ผ์ธ์ model_index.json
ํ์ผ์ ํตํด [DiffusionPipeline
]์ ๋ค์๊ณผ ๊ฐ์ ์ ๋ณด๋ฅผ ์ ๋ฌํฉ๋๋ค.
_class_name
๋ ์ด๋ค ํ์ดํ๋ผ์ธ ํด๋์ค๋ฅผ ์ฌ์ฉํด์ผ ํ๋์ง์ ๋ํด ์๋ ค์ค๋๋ค._diffusers_version
๋ ์ด๋ค ๋ฒ์ ์ ๋ํจ์ ์ค๋ก ํ์ดํ๋ผ์ธ ์์ ๋ชจ๋ธ๋ค์ด ๋ง๋ค์ด์ก๋์ง๋ฅผ ์๋ ค์ค๋๋ค.- ๊ทธ ๋ค์์ ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ค์ด ์ด๋ค ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ด๋ค ํด๋์ค๋ก ๋ง๋ค์ด์ก๋์ง์ ๋ํด ์๋ ค์ค๋๋ค. (์๋ ์์์์
"feature_extractor" : ["transformers", "CLIPImageProcessor"]
์ ๊ฒฝ์ฐ,feature_extractor
์ปดํฌ๋ํธ๋transformers
๋ผ์ด๋ธ๋ฌ๋ฆฌ์CLIPImageProcessor
ํด๋์ค๋ฅผ ํตํด ๋ง๋ค์ด์ก๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.)
{
"_class_name": "StableDiffusionPipeline",
"_diffusers_version": "0.6.0",
"feature_extractor": [
"transformers",
"CLIPImageProcessor"
],
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
],
"scheduler": [
"diffusers",
"PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"tokenizer": [
"transformers",
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}