yibolu committed
Commit 6eca12e
Parent: a38262d

update pipeline and demos
- README.md +19 -11
- controlnet_img2img_demo.py +6 -4
- controlnet_txt2img_demo.py +11 -5
- controlnet_txt2img_sdxl_demo.py +70 -0
- img2img_demo.py +5 -2
- lyrasd_model/__init__.py +5 -1
- lyrasd_model/lora_util.py +238 -6
- lyrasd_model/lyrasd_controlnet_img2img_pipeline.py +92 -110
- lyrasd_model/lyrasd_controlnet_txt2img_pipeline.py +40 -82
- lyrasd_model/lyrasd_img2img_pipeline.py +90 -95
- lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm80.so +0 -3
- lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm86.so +0 -3
- lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so +2 -2
- lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so +2 -2
- lyrasd_model/lyrasd_pipeline_base.py +214 -0
- lyrasd_model/lyrasd_txt2img_inpaint_pipeline.py +826 -0
- lyrasd_model/lyrasd_txt2img_pipeline.py +172 -85
- lyrasd_model/lyrasd_vae_model.py +363 -0
- lyrasd_model/lyrasdxl_controlnet_txt2img_pipeline.py +346 -0
- lyrasd_model/lyrasdxl_pipeline_base.py +275 -0
- lyrasd_model/lyrasdxl_txt2img_inpaint_pipeline.py +535 -0
- lyrasd_model/lyrasdxl_txt2img_pipeline.py +267 -0
- lyrasd_model/{lyrasd_lib/placeholder.txt → module/__init__.py} +0 -0
- lyrasd_model/module/lyra_tool.py +5 -0
- lyrasd_model/module/lyrasd_ip_adapter.py +289 -0
- lyrasd_model/module/resampler.py +121 -0
- lyrasd_model/module/tools.py +148 -0
- models/README.md +14 -5
- outputs/res_controlnet_img2img_0.png +2 -2
- outputs/{res_controlnet_sdxl_txt2img.png → res_controlnet_sdxl_txt2img_0.png} +2 -2
- outputs/res_controlnet_txt2img_0.png +2 -2
- outputs/res_img2img_0.png +2 -2
- outputs/res_txt2img_lora_0.png +2 -2
- outputs/{res_sdxl_txt2img_lora_0.png → res_txt2img_xl_lora_0.png} +2 -2
- txt2img_demo.py +13 -10
- txt2img_sdxl_demo.py +55 -0
README.md CHANGED

@@ -79,12 +79,16 @@ from lyrasd_model import LyraSdTxt2ImgPipeline
 # 4. scheduler configuration
 
 # LyraSD's compiled C++ dynamic library, containing the C++ CUDA compute details
-lib_path = "./lyrasd_model/lyrasd_lib/
-model_path = "./models/
+lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
+model_path = "./models/rev-animated"
 lora_path = "./models/xiaorenshu.safetensors"
 
+torch.classes.load_library(lib_path)
+
 # Build the Txt2Img pipeline
-model = LyraSdTxt2ImgPipeline(
+model = LyraSdTxt2ImgPipeline()
+
+model.reload_pipe(model_path)
 
 # load lora
 # lora model path, name, lora strength

@@ -94,7 +98,7 @@ model.load_lora_v2(lora_path, "xiaorenshu", 0.4)
 prompt = "a cat, cute, cartoon, concise, traditional, chinese painting, Tang and Song Dynasties, masterpiece, 4k, 8k, UHD, best quality"
 negative_prompt = "(((horrible))), (((scary))), (((naked))), (((large breasts))), high saturation, colorful, human:2, body:2, low quality, bad quality, lowres, out of frame, duplicate, watermark, signature, text, frames, cut, cropped, malformed limbs, extra limbs, (((missing arms))), (((missing legs)))"
 height, width = 512, 512
-steps =
+steps = 20
 guidance_scale = 7
 generator = torch.Generator().manual_seed(123)
 num_images = 1

@@ -128,12 +132,16 @@ from lyrasd_model import LyraSdXLTxt2ImgPipeline
 # 4. scheduler configuration
 
 # LyraSD's compiled C++ dynamic library, containing the C++ CUDA compute details
-lib_path = "./lyrasd_model/lyrasd_lib/
-model_path = "./models/
+lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
+model_path = "./models/helloworldSDXL20Fp16"
 lora_path = "./models/dissolve_sdxl.safetensors"
 
+torch.classes.load_library(lib_path)
+
 # Build the Txt2Img pipeline
-model = LyraSdXLTxt2ImgPipeline(
+model = LyraSdXLTxt2ImgPipeline()
+
+model.reload_pipe(model_path)
 
 # load lora
 # lora model path, name, lora strength

@@ -143,7 +151,7 @@ model.load_lora_v2(lora_path, "dissolve_sdxl", 0.4)
 prompt = "a cat, cute, cartoon, concise, traditional, chinese painting, Tang and Song Dynasties, masterpiece, 4k, 8k, UHD, best quality"
 negative_prompt = "(((horrible))), (((scary))), (((naked))), (((large breasts))), high saturation, colorful, human:2, body:2, low quality, bad quality, lowres, out of frame, duplicate, watermark, signature, text, frames, cut, cropped, malformed limbs, extra limbs, (((missing arms))), (((missing legs)))"
 height, width = 512, 512
-steps =
+steps = 20
 guidance_scale = 7
 generator = torch.Generator().manual_seed(123)
 num_images = 1

@@ -181,7 +189,7 @@ model.unload_lora_v2("dissolve_sdxl", True)
 ![text2img_demo](./outputs/res_sdxl_txt2img_0.png)
 
 #### SDXL Text2Img with Lora
-![text2img_demo](./outputs/
+![text2img_demo](./outputs/res_txt2img_xl_lora_0.png)
 
 
 <!-- ### Img2Img

@@ -201,7 +209,7 @@ model.unload_lora_v2("dissolve_sdxl", True)
 ![text2img_demo](./outputs/res_controlnet_txt2img_0.png)
 
 #### SDXL ControlNet Text2Img Output
-![text2img_demo](./outputs/
+![text2img_demo](./outputs/res_controlnet_sdxl_txt2img_0.png)
 
 
 ## Docker Environment Recommendation

@@ -218,7 +226,7 @@ python txt2img_demo.py
 
 ## Citation
 ``` bibtex
-@Misc{
+@Misc{lyraSD_2024,
 author = {Kangjian Wu, Zhengtao Wang, Yibo Lu, Haoxiong Su, Sa Xiao, Bin Wu},
 title = {lyraSD: Accelerating Stable Diffusion with best flexibility},
 howpublished = {\url{https://huggingface.co/TMElyralab/lyraSD}},
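Taken together, the README hunks above move both pipelines to a two-step setup: load the compiled extension, construct the pipeline with no arguments, then point it at a converted model directory with `reload_pipe`. Below is a minimal end-to-end sketch of the new SD1.x flow; the paths come from the README, while the call keywords are assumed to mirror the SDXL ControlNet demo added in this commit rather than a documented signature.

```python
import torch
from lyrasd_model import LyraSdTxt2ImgPipeline

# Load the compiled lyraSD extension before building any pipeline.
torch.classes.load_library("./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so")

# New API: empty constructor, then reload_pipe on a converted model directory.
model = LyraSdTxt2ImgPipeline()
model.reload_pipe("./models/rev-animated")

# Optional LoRA: path, name, strength (as in the README).
model.load_lora_v2("./models/xiaorenshu.safetensors", "xiaorenshu", 0.4)

# Keyword names below are assumed from the SDXL ControlNet demo in this commit.
images = model(prompt="a cat, cute, cartoon, chinese painting",
               height=512, width=512, num_inference_steps=20,
               guidance_scale=7, negative_prompt="low quality, watermark",
               num_images_per_prompt=1,
               generator=torch.Generator().manual_seed(123))
images[0].save("./outputs/res_txt2img_lora_0.png")

# Remove the LoRA again when it is no longer needed.
model.unload_lora_v2("xiaorenshu", True)
```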
controlnet_img2img_demo.py CHANGED

@@ -14,14 +14,16 @@ from lyrasd_model import LyraSdControlnetImg2ImgPipeline
 
 # LyraSD's compiled C++ dynamic library, containing the C++ CUDA compute details
 lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so"
-model_path = "./models/
-canny_controlnet_path = "./models/
+model_path = "./models/rev-animated"
+canny_controlnet_path = "./models/canny"
+torch.classes.load_library(lib_path)
 
 # Build the Img2Img pipeline
-model = LyraSdControlnetImg2ImgPipeline(
+model = LyraSdControlnetImg2ImgPipeline()
+model.reload_pipe(model_path)
 
 # Load a ControlNet model (at most 3 can be loaded)
-model.
+model.load_controlnet_model_v2("canny", canny_controlnet_path)
 
 control_img = Image.open("control_bird_canny.png")
 
controlnet_txt2img_demo.py CHANGED

@@ -12,16 +12,22 @@ from lyrasd_model import LyraSdControlnetTxt2ImgPipeline
 # 5. scheduler configuration
 
 # LyraSD's compiled C++ dynamic library, containing the C++ CUDA compute details
-lib_path = "./lyrasd_model/lyrasd_lib/
-model_path = "./models/
-canny_controlnet_path = "./models/
+lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
+model_path = "./models/rev-animated"
+canny_controlnet_path = "./models/canny"
+
+torch.classes.load_library(lib_path)
+
 # Build the Txt2Img pipeline
-pipe = LyraSdControlnetTxt2ImgPipeline(
+pipe = LyraSdControlnetTxt2ImgPipeline()
+
+pipe.reload_pipe(model_path)
 
 # Load a ControlNet model (at most 3 can be loaded)
 start = time.perf_counter()
-pipe.
+pipe.load_controlnet_model_v2("canny", canny_controlnet_path)
 print(f"controlnet load cost: {time.perf_counter() - start}")
+
 # get_loaded_controlnet returns the list of ControlNets that are currently loaded
 print(pipe.get_loaded_controlnet())
 
controlnet_txt2img_sdxl_demo.py ADDED

@@ -0,0 +1,70 @@
+import torch
+import time
+from PIL import Image
+import numpy as np
+from lyrasd_model import LyraSdXLControlnetTxt2ImgPipeline
+import GPUtil
+
+# Path to the model files, which should contain the following structure:
+# 1. clip model
+# 2. converted and optimized unet model
+# 3. converted and optimized controlnet model
+# 4. vae model
+# 5. scheduler configuration
+lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
+model_path = "./models/helloworldSDXL20Fp16"
+torch.classes.load_library(lib_path)
+
+# Build the Txt2Img pipeline
+pipe = LyraSdXLControlnetTxt2ImgPipeline()
+
+start = time.perf_counter()
+pipe.reload_pipe(model_path)
+print(f"pipeline load cost: {time.perf_counter() - start}")
+
+# Load a ControlNet model (at most 3 can be loaded)
+start = time.perf_counter()
+pipe.load_controlnet_model_v2("canny", "./models/controlnet-canny-sdxl-1.0")
+print(f"controlnet load cost: {time.perf_counter() - start}")
+
+# get_loaded_controlnet returns the list of ControlNets that are currently loaded
+print(pipe.get_loaded_controlnet())
+
+# unload_controlnet_model can be used to unload a ControlNet
+# pipe.unload_controlnet_model("canny")
+
+control_img = Image.open("control_bird_canny.png")
+
+# Prepare the inputs and hyperparameters
+prompt = "a bird"
+negative_prompt = ""
+height, width = 1024, 1024
+steps = 20
+guidance_scale = 7.5
+generator = torch.Generator().manual_seed(123)
+num_images = 1
+guess_mode = False
+
+# Up to 3 ControlNets can be loaded at once for a multi-ControlNet effect; the parameter lists must be aligned
+# The image list passed in should have the same length as controlnet_scale and controlnet_names, and each inner list should match the batch size,
+# with the corresponding indices aligned
+controlnet_images = [[control_img]]
+controlnet_scale = [0.5]
+controlnet_names = ['canny']
+
+# Run inference; the results are generated PIL.Image objects
+for batch in [1]:
+    print(f"cur batch: {batch}")
+    for _ in range(3):
+        start = time.perf_counter()
+        images = pipe(prompt=prompt, height=height, width=width, num_inference_steps=steps,
+                      guidance_scale=guidance_scale, negative_prompt=negative_prompt, num_images_per_prompt=batch,
+                      generator=generator, controlnet_images=controlnet_images,
+                      controlnet_scale=controlnet_scale, controlnet_names=controlnet_names,
+                      guess_mode=guess_mode
+                      )
+        print("cur cost: ", time.perf_counter() - start)
+        GPUtil.showUtilization(all=True)
+# Save the generated images
+for i, image in enumerate(images):
+    image.save(f"./outputs/res_controlnet_sdxl_txt2img_{i}.png")
img2img_demo.py CHANGED

@@ -14,10 +14,13 @@ from lyrasd_model import LyraSDImg2ImgPipeline
 
 # LyraSD's compiled C++ dynamic library, containing the C++ CUDA compute details
 lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so"
-model_path = "./models/
+model_path = "./models/rev-animated"
+
+torch.classes.load_library(lib_path)
 
 # Build the Img2Img pipeline
-model = LyraSDImg2ImgPipeline(
+model = LyraSDImg2ImgPipeline()
+model.reload_pipe(model_path)
 
 # Prepare the inputs and hyperparameters
 prompt = "a cat, cartoon style"
lyrasd_model/__init__.py CHANGED

@@ -1,5 +1,9 @@
 from . import lyrasd_img2img_pipeline, lyrasd_txt2img_pipeline, lyrasd_controlnet_txt2img_pipeline, lyrasd_controlnet_img2img_pipeline
 from .lyrasd_txt2img_pipeline import LyraSdTxt2ImgPipeline
 from .lyrasd_img2img_pipeline import LyraSDImg2ImgPipeline
+from .lyrasd_txt2img_inpaint_pipeline import LyraSdTxt2ImgInpaintPipeline
 from .lyrasd_controlnet_txt2img_pipeline import LyraSdControlnetTxt2ImgPipeline
-from .lyrasd_controlnet_img2img_pipeline import LyraSdControlnetImg2ImgPipeline
+from .lyrasd_controlnet_img2img_pipeline import LyraSdControlnetImg2ImgPipeline
+from .lyrasdxl_txt2img_pipeline import LyraSdXLTxt2ImgPipeline
+from .lyrasdxl_controlnet_txt2img_pipeline import LyraSdXLControlnetTxt2ImgPipeline
+from .lyrasdxl_txt2img_inpaint_pipeline import LyraSdXLTxt2ImgInpaintPipeline
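For reference, after this change the package exports the SD1.x, inpaint, and SDXL pipelines side by side. A quick import check, listing only the classes the diff above exposes:

```python
# Nothing beyond what lyrasd_model/__init__.py re-exports in this commit.
from lyrasd_model import (
    LyraSdTxt2ImgPipeline,
    LyraSdTxt2ImgInpaintPipeline,
    LyraSDImg2ImgPipeline,
    LyraSdControlnetTxt2ImgPipeline,
    LyraSdControlnetImg2ImgPipeline,
    LyraSdXLTxt2ImgPipeline,
    LyraSdXLControlnetTxt2ImgPipeline,
    LyraSdXLTxt2ImgInpaintPipeline,
)
```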
lyrasd_model/lora_util.py CHANGED

@@ -1,7 +1,18 @@
 import os
+import re
+import time
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 import numpy as np
+from safetensors.torch import load_file
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.loaders.lora_conversion_utils import _maybe_map_sgm_blocks_to_diffusers, _convert_kohya_lora_to_diffusers
+from types import SimpleNamespace
+import logging.handlers
+LORA_PREFIX_UNET = "lora_unet"
+LORA_PREFIX_TEXT_ENCODER = "lora_te"
+LORA_UNET_LAYERS = ['lora_unet_down_blocks_0_attentions_0', 'lora_unet_down_blocks_0_attentions_1', 'lora_unet_down_blocks_1_attentions_0', 'lora_unet_down_blocks_1_attentions_1', 'lora_unet_down_blocks_2_attentions_0', 'lora_unet_down_blocks_2_attentions_1', 'lora_unet_mid_block_attentions_0', 'lora_unet_up_blocks_1_attentions_0',
+                    'lora_unet_up_blocks_1_attentions_1', 'lora_unet_up_blocks_1_attentions_2', 'lora_unet_up_blocks_2_attentions_0', 'lora_unet_up_blocks_2_attentions_1', 'lora_unet_up_blocks_2_attentions_2', 'lora_unet_up_blocks_3_attentions_0', 'lora_unet_up_blocks_3_attentions_1', 'lora_unet_up_blocks_3_attentions_2']
+
 
 def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=1.0, lora_file_format="fp32", device="cuda:0"):
     if lora_file_format == "fp32":

@@ -14,9 +25,10 @@
     unload_dict = []
     # directly update weight in diffusers model
     for file in all_files:
-
+
         if 'text' in file.name:
-            layer_infos = file.name.split('.')[0].split(
+            layer_infos = file.name.split('.')[0].split(
+                'text_model_')[-1].split('_')
             curr_layer = clip_model.text_model
         else:
             continue

@@ -39,9 +51,71 @@
                     temp_name += '_'+layer_infos.pop(0)
                 else:
                     temp_name = layer_infos.pop(0)
-        data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
+        data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
+            clip_model.dtype).to(clip_model.device).reshape(curr_layer.weight.data.shape)
+        if len(curr_layer.weight.data) == 4:
+            adding_weight = alpha * data.permute(0, 3, 1, 2)
+        else:
+            adding_weight = alpha * data
+        curr_layer.weight.data += adding_weight
+
+        curr_layer_unload_data = {
+            "layer": curr_layer,
+            "added_weight": adding_weight
+        }
+        unload_dict.append(curr_layer_unload_data)
+    return unload_dict
+
+
+def add_xltext_lora_layer(clip_model, clip_model_2, lora_model_path, alpha=1.0, lora_file_format="fp32", device="cuda:0"):
+    if lora_file_format == "fp32":
+        model_dtype = np.float32
+    elif lora_file_format == "fp16":
+        model_dtype = np.float16
+    else:
+        raise Exception(f"unsupported model dtype: {lora_file_format}")
+    all_files = os.scandir(lora_model_path)
+    unload_dict = []
+    # directly update weight in diffusers model
+    for file in all_files:
+
+        if 'text' in file.name:
+            layer_infos = file.name.split('.')[0].split(
+                'text_model_')[-1].split('_')
+            if "text_encoder_2" in file.name:
+                curr_layer = clip_model_2.text_model
+            elif "text_encoder" in file.name:
+                curr_layer = clip_model.text_model
+            else:
+                raise ValueError(
+                    "Cannot identify clip model, need text_encoder or text_encoder_2 in filename, found: ", file.name)
+        else:
+            continue
+
+        # find the target layer
+        # find the target layer
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                # if temp_name == "self":
+                #     temp_name += "_" + layer_infos.pop(0)
+                # elif temp_name != "mlp" and len(layer_infos) == 1:
+                #     temp_name += "_" + layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(temp_name) > 0:
+                    temp_name += '_'+layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
+            clip_model.dtype).to(clip_model.device).reshape(curr_layer.weight.data.shape)
         if len(curr_layer.weight.data) == 4:
-            adding_weight = alpha * data.permute(0,3,1,2)
+            adding_weight = alpha * data.permute(0, 3, 1, 2)
         else:
             adding_weight = alpha * data
         curr_layer.weight.data += adding_weight

@@ -51,4 +125,162 @@
             "added_weight": adding_weight
         }
         unload_dict.append(curr_layer_unload_data)
-    return unload_dict
+    return unload_dict
+
+
+def lora_trans(state_dict):
+    loraload = LoraLoaderMixin()
+    unet_config = SimpleNamespace(**{'layers_per_block': 2})
+    state_dicts = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+    state_dicts_trans, state_dicts_alpha = _convert_kohya_lora_to_diffusers(
+        state_dicts)
+    keys = list(state_dicts_trans.keys())
+    for k in keys:
+        key = k.replace('processor.', '')
+        for x in ['.lora_linear_layer.', '_lora.', '.lora.']:
+            key = key.replace(x, '.lora_')
+        if key.find('text_encoder') >= 0:
+            for x in ['q', 'k', 'v', 'out']:
+                key = key.replace(f'.to_{x}.', f'.{x}_proj.')
+        key = key.replace('to_out.', 'to_out.0.')
+        if key != k:
+            state_dicts_trans[key] = state_dicts_trans.pop(k)
+    alpha = torch.Tensor(list(set(list(state_dicts_alpha.values()))))
+    state_dicts_trans.update({'lora.alpha': alpha})
+
+    return state_dicts_trans
+
+
+def load_state_dict(filename, need_trans=True):
+    state_dict = load_file(os.path.abspath(filename), device="cpu")
+    if need_trans:
+        state_dict = lora_trans(state_dict)
+    return state_dict
+
+
+def move_state_dict_to_cuda(state_dict):
+    ret_state_dict = {}
+    for item in state_dict:
+        ret_state_dict[item] = state_dict[item].cuda()
+    return ret_state_dict
+
+
+def add_lora_to_opt_model(state_dict, unet, clip_model, clip_model_2, alpha=1.0, need_trans=False):
+    # directly update weight in diffusers model
+    state_dict = move_state_dict_to_cuda(state_dict)
+
+    alpha_ks = list(filter(lambda x: x.find('.alpha') >= 0, state_dict))
+    lora_alpha = state_dict[alpha_ks[0]].item() if len(alpha_ks) > 0 else -1
+
+    visited = set()
+    for key in state_dict:
+        # print(key)
+        # it is suggested to print out the key, it usually will be something like below
+        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+
+        # as we have set the alpha beforehand, so just skip
+        if '.alpha' in key or key in visited:
+            continue
+
+        if "text" in key:
+            curr_layer = clip_model_2 if key.find(
+                'text_encoder_2') >= 0 else clip_model
+
+            # if is_sdxl:
+            layer_infos = key.split('.')[1:]
+
+            for x in layer_infos:
+                try:
+                    curr_layer = curr_layer.__getattr__(x)
+                except Exception:
+                    break
+
+            # update weight
+            pair_keys = [key.replace("lora_down", "lora_up"),
+                         key.replace("lora_up", "lora_down")]
+            weight_up, weight_down = state_dict[pair_keys[0]
+                                                ], state_dict[pair_keys[1]]
+
+            weight_scale = lora_alpha/weight_up.shape[1] if lora_alpha != -1 else 1.0
+
+            if len(weight_up.shape) == 4:
+                weight_up = weight_up.squeeze([2, 3])
+                weight_down = weight_down.squeeze([2, 3])
+                if len(weight_down.shape) == 4:
+                    adding_weight = torch.einsum(
+                        'a b, b c h w -> a c h w', weight_up, weight_down)
+                else:
+                    adding_weight = torch.mm(
+                        weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+            else:
+                adding_weight = torch.mm(weight_up, weight_down)
+            adding_weight = alpha * weight_scale * adding_weight
+
+            curr_layer.weight.data += adding_weight.to(torch.float16)
+            # update visited list
+            for item in pair_keys:
+                visited.add(item)
+
+        elif "unet" in key:
+            layer_infos = key
+            layer_infos = layer_infos.replace(".lora_up.weight", "")
+            layer_infos = layer_infos.replace(".lora_down.weight", "")
+
+            layer_infos = layer_infos[5:]
+            layer_names = layer_infos.split(".")
+
+            layers = []
+            i = 0
+            while i < len(layer_names):
+
+                if len(layers) >= 4:
+                    layers[-1] += "_" + layer_names[i]
+                elif i + 1 < len(layer_names) and layer_names[i+1].isdigit():
+                    layers.append(layer_names[i] + "_" + layer_names[i+1])
+                    i += 1
+                elif len(layers) > 0 and "samplers" in layers[-1]:
+                    layers[-1] += "_" + layer_names[i]
+                else:
+                    layers.append(layer_names[i])
+                i += 1
+            layer_infos = ".".join(layers)
+
+            pair_keys = [key.replace("lora_down", "lora_up"),
+                         key.replace("lora_up", "lora_down")]
+
+            # update weight
+            if len(state_dict[pair_keys[0]].shape) == 4:
+                weight_up = state_dict[pair_keys[0]].squeeze(
+                    3).squeeze(2).to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].to(torch.float32)
+                weight_scale = lora_alpha/weight_up.shape[1] if lora_alpha != -1 else 1.0
+
+                weight_up, weight_down = state_dict[pair_keys[0]
+                                                    ], state_dict[pair_keys[1]]
+                weight_up = weight_up.squeeze([2, 3]).to(torch.float32)
+                weight_down = weight_down.squeeze([2, 3]).to(torch.float32)
+                if len(weight_down.shape) == 4:
+                    curr_layer_weight = weight_scale * \
+                        torch.einsum('a b, b c h w -> a c h w',
+                                     weight_up, weight_down)
+                else:
+                    curr_layer_weight = weight_scale * \
+                        torch.mm(weight_up, weight_down).unsqueeze(
+                            2).unsqueeze(3)
+
+                curr_layer_weight = curr_layer_weight.permute(0, 2, 3, 1)
+
+            else:
+                weight_up = state_dict[pair_keys[0]].to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].to(torch.float32)
+                weight_scale = lora_alpha/weight_up.shape[1] if lora_alpha != -1 else 1.0
+
+                curr_layer_weight = weight_scale * \
+                    torch.mm(weight_up, weight_down)
+            #
+
+            curr_layer_weight = curr_layer_weight.to(torch.float16)
+
+            unet.load_lora_by_name(layers, curr_layer_weight, alpha)
+
+            for item in pair_keys:
+                visited.add(item)
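The new helpers above load a kohya-style .safetensors LoRA, remap it to diffusers key layout (`lora_trans`), patch the CLIP text encoder(s) in place, and hand the per-layer UNet deltas to the optimized C++ op through `unet.load_lora_by_name`. A minimal sketch of how they appear intended to fit together; the `pipe` object and its attributes are assumptions based on this diff, not a documented API:

```python
from lyrasd_model.lora_util import load_state_dict, add_lora_to_opt_model

# Load and convert a kohya/SGM-style LoRA checkpoint to diffusers naming.
state_dict = load_state_dict("./models/dissolve_sdxl.safetensors", need_trans=True)

# Patch the text encoders directly and push the UNet weight deltas into the
# C++ UNet op. `pipe` is assumed to be an SDXL pipeline that has already been
# set up with reload_pipe and exposes unet / text_encoder / text_encoder_2.
add_lora_to_opt_model(state_dict, pipe.unet, pipe.text_encoder,
                      pipe.text_encoder_2, alpha=0.4)
```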
lyrasd_model/lyrasd_controlnet_img2img_pipeline.py CHANGED

@@ -1,21 +1,18 @@
 import torch
 from typing import Any, Callable, Dict, List, Optional, Union
-from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.loaders import TextualInversionLoaderMixin
-from diffusers.
-from diffusers.utils import randn_tensor, logging
-from diffusers.schedulers import EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
+from diffusers.utils.torch_utils import logging, randn_tensor
 from diffusers.utils import PIL_INTERPOLATION
-
+
 import os
 import numpy as np
 import warnings
-from .lora_util import add_text_lora_layer
-import gc
 
 from PIL import Image
 import PIL
 
+from .lyrasd_pipeline_base import LyraSDXLPipelineBase
+
 import inspect
 
 import time

@@ -31,7 +28,8 @@ def numpy_to_pil(images):
     images = (images * 255).round().astype("uint8")
     if images.shape[-1] == 1:
         # special case for grayscale (single channel) images
-        pil_images = [Image.fromarray(image.squeeze(), mode="L")
+        pil_images = [Image.fromarray(image.squeeze(), mode="L")
+                      for image in images]
     else:
         pil_images = [Image.fromarray(image) for image in images]
 

@@ -53,7 +51,8 @@ def preprocess(image):
         w, h = image[0].size
         w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
 
-        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
+        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
+            None, :] for i in image]
         image = np.concatenate(image, axis=0)
         image = np.array(image).astype(np.float32) / 255.0
         image = image.transpose(0, 3, 1, 2)

@@ -63,69 +62,11 @@ def preprocess(image):
         image = torch.cat(image, dim=0)
     return image
 
-class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
-    def __init__(self, model_path, lib_so_path, model_dtype='fp32', device=torch.device("cuda"), dtype=torch.float16) -> None:
-        self.device = device
-        self.dtype = dtype
-
-        torch.classes.load_library(lib_so_path)
-
-        self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(dtype).to(device)
-        self.tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
-        self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(dtype).to(device)
-
-        self.unet_in_channels = 4
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.vae.enable_tiling()
-        self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
-            3,       # max num of controlnets
-            "fp16"   # inference dtype (can only use fp16 for now)
-        )
-
-        unet_path = os.path.join(model_path, "unet_bins/")
-
-        self.reload_unet_model(unet_path, model_dtype)
-
-        self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
-
-    def load_controlnet_model(self, model_name, controlnet_path, model_dtype="fp32"):
-        if len(controlnet_path) > 0 and controlnet_path[-1] != "/":
-            controlnet_path = controlnet_path + "/"
-        self.unet.load_controlnet_model(model_name, controlnet_path, model_dtype)
-
-    def unload_controlnet_model(self, model_name):
-        self.unet.unload_controlnet_model(model_name, True)
-
-    def get_loaded_controlnet(self):
-        return self.unet.get_loaded_controlnet()
-
-    def reload_unet_model(self, unet_path, unet_file_format='fp32'):
-        if len(unet_path) > 0 and unet_path[-1] != "/":
-            unet_path = unet_path + "/"
-        return self.unet.reload_unet_model(unet_path, unet_file_format)
-
-    def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
-        if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
-            lora_model_path = lora_model_path + "/"
-        lora = add_text_lora_layer(self.text_encoder, lora_model_path, lora_strength, lora_file_format)
-        self.loaded_lora[lora_name] = lora
-        self.unet.load_lora(lora_model_path, lora_name, lora_strength, lora_file_format)
 
-    def unload_lora(self, lora_name, clean_cache=False):
-        for layer_data in self.loaded_lora[lora_name]:
-            layer = layer_data['layer']
-            added_weight = layer_data['added_weight']
-            layer.weight.data -= added_weight
-        self.unet.unload_lora(lora_name, clean_cache)
-        del self.loaded_lora[lora_name]
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def clean_lora_cache(self):
-        self.unet.clean_lora_cache()
-
-    def get_loaded_lora(self):
-        return self.unet.get_loaded_lora()
+class LyraSdControlnetImg2ImgPipeline(LyraSDXLPipelineBase):
+    def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
+        super().__init__(device, dtype, vae_scale_factor=vae_scale_factor,
+                         vae_scaling_factor=vae_scaling_factor)
 
     def _encode_prompt(
         self,

@@ -181,13 +122,14 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
                 return_tensors="pt",
             )
             text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(
+            untruncated_ids = self.tokenizer(
+                prompt, padding="longest", return_tensors="pt").input_ids
 
             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                 text_input_ids, untruncated_ids
             ):
                 removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
                 )
                 logger.warning(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"

@@ -205,12 +147,14 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
             )
             prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(
+        prompt_embeds = prompt_embeds.to(
+            dtype=self.text_encoder.dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(
+        prompt_embeds = prompt_embeds.view(
+            bs_embed * num_images_per_prompt, seq_len, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:

@@ -235,7 +179,8 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
 
             # textual inversion: procecss multi-vector tokens if necessary
             if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokens = self.maybe_convert_prompt(
+                uncond_tokens = self.maybe_convert_prompt(
+                    uncond_tokens, self.tokenizer)
 
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(

@@ -261,10 +206,13 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(
+            negative_prompt_embeds = negative_prompt_embeds.to(
+                dtype=self.text_encoder.dtype, device=device)
 
-            negative_prompt_embeds = negative_prompt_embeds.repeat(
-
+            negative_prompt_embeds = negative_prompt_embeds.repeat(
+                1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch

@@ -272,7 +220,6 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         return prompt_embeds
-
 
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents

@@ -282,6 +229,17 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
+    def lyra_decode_latents(self, latents):
+        print("lyra_decode_latents")
+        latents = 1 / self.vae_scaling_factor * latents
+        image = self.vae.decode(latents)
+        image = image.permute(0, 2, 3, 1)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().float().numpy()
+
+        return image
+
     def check_inputs(
         self,
         prompt,

@@ -291,8 +249,9 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
         prompt_embeds=None,
         negative_prompt_embeds=None,
     ):
-        if height % 64 != 0 or width % 64 != 0:
-            raise ValueError(
+        if height % 64 != 0 or width % 64 != 0:  # the initial version only supports height and width that are multiples of 64
+            raise ValueError(
+                f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
 
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(

@@ -304,7 +263,8 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(

@@ -342,13 +302,14 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
 
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i: i + 1]).
+                self.vae.encode(image[i: i + 1]).sample(generator[i]) for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(
+            init_latents = self.vae.encode(
+                image).sample(generator)
 
-            init_latents = self.vae.
+        init_latents = self.vae.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -358,9 +319,9 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
                 " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                 " your script to pass as many initial images as text prompts to suppress this warning."
             )
-            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
             additional_image_per_prompt = batch_size // init_latents.shape[0]
-            init_latents = torch.cat(
+            init_latents = torch.cat(
+                [init_latents] * additional_image_per_prompt, dim=0)
         elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
             raise ValueError(
                 f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."

@@ -369,7 +330,8 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
             init_latents = torch.cat([init_latents], dim=0)
 
         shape = init_latents.shape
-        noise = randn_tensor(shape, generator=generator,
+        noise = randn_tensor(shape, generator=generator,
+                             device=device, dtype=dtype)
 
         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

@@ -398,7 +360,8 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
 
             for image_ in image:
                 image_ = image_.convert("RGB")
-                image_ = image_.resize(
+                image_ = image_.resize(
+                    (width, height), resample=PIL_INTERPOLATION["lanczos"])
                 image_ = np.array(image_)
                 image_ = image_[None, :]
                 images.append(image_)

@@ -434,27 +397,29 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
 
-        accepts_eta = "eta" in set(inspect.signature(
+        accepts_eta = "eta" in set(inspect.signature(
+            self.scheduler.step).parameters.keys())
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(
+        accepts_generator = "generator" in set(
+            inspect.signature(self.scheduler.step).parameters.keys())
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
-        init_timestep = min(
+        init_timestep = min(
+            int(num_inference_steps * strength), num_inference_steps)
 
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
 
         return timesteps, num_inference_steps - t_start
 
-
     @torch.no_grad()
     def __call__(
         self,

@@ -477,9 +442,10 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
         controlnet_images: Optional[List[PIL.Image.Image]] = None,
         controlnet_scale: Optional[List[float]] = None,
         controlnet_names: Optional[List[str]] = None,
-        guess_mode
+        guess_mode=False,
         eta: float = 0.0,
-        generator: Optional[Union[torch.Generator,
+        generator: Optional[Union[torch.Generator,
+                            List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,

@@ -549,7 +515,6 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
-
         # 3. Encode input prompt
         start = time.perf_counter()
         prompt_embeds = self._encode_prompt(

@@ -583,17 +548,21 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
             scales = [1.0, ] * 13
             if guess_mode:
                 scales = torch.logspace(-1, 0, 13).tolist()
-
+
             for scale in controlnet_scale:
                 scales_ = [d * scale for d in scales]
                 control_scales.append(scales_)
 
-
-
+        print(f"clip cost: {(time.perf_counter() - start)* 1000}")
+
+        image = self.image_processor.preprocess(image)
+
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, num_inference_steps = self.get_timesteps(
-
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps, strength, device)
+        latent_timestep = timesteps[:1].repeat(
+            batch_size * num_images_per_prompt)
 
         # 6. Prepare latent variables
         latents = self.prepare_latents(

@@ -604,33 +573,46 @@ class LyraSdControlnetImg2ImgPipeline(TextualInversionLoaderMixin):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 8. Denoising loop
-        num_warmup_steps = len(timesteps) -
+        num_warmup_steps = len(timesteps) - \
+            num_inference_steps * self.scheduler.order
 
         start_unet = time.perf_counter()
         for i, t in enumerate(timesteps):
             # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat(
-
-            latent_model_input =
+            latent_model_input = torch.cat(
+                [latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(
+                latent_model_input, t)
+            latent_model_input = latent_model_input.permute(
+                0, 2, 3, 1).contiguous()
 
             # the last three None values are parameters for the ControlNet, passed as None placeholders for now
-            noise_pred = self.unet.forward(
+            noise_pred = self.unet.forward(
+                latent_model_input, prompt_embeds, t, controlnet_names, control_images, control_scales, guess_mode)
 
             noise_pred = noise_pred.permute(0, 3, 1, 2)
             # perform guidance
 
             if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale *
+                noise_pred = noise_pred_uncond + guidance_scale * \
+                    (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(
+            latents = self.scheduler.step(
+                noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
         torch.cuda.synchronize()
 
+        print(
+            f"unet x {num_inference_steps} cost: {(time.perf_counter() - start_unet) * 1000}")
+
         start = time.perf_counter()
-        image = self.decode_latents(latents)
+        # image = self.decode_latents(latents)
+        image = self.lyra_decode_latents(latents)
         torch.cuda.synchronize()
+        print(f"vae cost: {(time.perf_counter() - start)* 1000}")
+        print()
        image = numpy_to_pil(image)
 
         return image
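One detail worth noting in the rewritten denoising loop: the C++ UNet op consumes and returns channels-last (NHWC) tensors, so the pipeline permutes the latents before `unet.forward` and permutes the prediction back to NCHW for the scheduler. A stripped-down sketch of just that layout round trip (shapes are illustrative, not taken from the diff):

```python
import torch

latents = torch.randn(2, 4, 64, 64, dtype=torch.float16, device="cuda")  # NCHW

# lyraSD's UNet op expects contiguous channels-last input.
unet_input = latents.permute(0, 2, 3, 1).contiguous()   # (2, 64, 64, 4)

# noise_pred_nhwc = self.unet.forward(unet_input, prompt_embeds, t, ...)
noise_pred_nhwc = torch.zeros_like(unet_input)           # stand-in for the op's NHWC output

# Back to NCHW before calling scheduler.step(noise_pred, t, latents, ...).
noise_pred = noise_pred_nhwc.permute(0, 3, 1, 2)
```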
lyrasd_model/lyrasd_controlnet_txt2img_pipeline.py
CHANGED
@@ -1,12 +1,8 @@
 import torch
 from typing import Any, Callable, Dict, List, Optional, Union
-from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.loaders import TextualInversionLoaderMixin
-from diffusers.
-from diffusers.utils import randn_tensor, logging
-from diffusers.schedulers import EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
 from diffusers.utils import PIL_INTERPOLATION
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 import os
 import numpy as np
 from .lora_util import add_text_lora_layer
@@ -17,6 +13,7 @@ import PIL
 import inspect

 import time

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -36,68 +33,11 @@ def numpy_to_pil(images):
     return pil_images


-class LyraSdControlnetTxt2ImgPipeline(
-    def __init__(self,
-
-
-
-        torch.classes.load_library(lib_so_path)
-
-        self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(dtype).to(device)
-        self.tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
-        self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(dtype).to(device)
-        self.unet_in_channels = 4
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.vae.enable_tiling()
-        self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
-            3,      # max num of controlnets
-            "fp16"  # inference dtype (can only use fp16 for now)
-        )
-
-        unet_path = os.path.join(model_path, "unet_bins/")
-        self.reload_unet_model(unet_path, model_dtype)
-
-        self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
-
-    def load_controlnet_model(self, model_name, controlnet_path, model_dtype="fp32"):
-        if len(controlnet_path) > 0 and controlnet_path[-1] != "/":
-            controlnet_path = controlnet_path + "/"
-        self.unet.load_controlnet_model(model_name, controlnet_path, model_dtype)
-
-    def unload_controlnet_model(self, model_name):
-        self.unet.unload_controlnet_model(model_name, True)
-
-    def get_loaded_controlnet(self):
-        return self.unet.get_loaded_controlnet()
-
-    def reload_unet_model(self, unet_path, unet_file_format='fp32'):
-        if len(unet_path) > 0 and unet_path[-1] != "/":
-            unet_path = unet_path + "/"
-        return self.unet.reload_unet_model(unet_path, unet_file_format)
-
-    def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
-        if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
-            lora_model_path = lora_model_path + "/"
-        lora = add_text_lora_layer(self.text_encoder, lora_model_path, lora_strength, lora_file_format)
-        self.loaded_lora[lora_name] = lora
-        self.unet.load_lora(lora_model_path, lora_name, lora_strength, lora_file_format)
-
-    def unload_lora(self, lora_name, clean_cache=False):
-        for layer_data in self.loaded_lora[lora_name]:
-            layer = layer_data['layer']
-            added_weight = layer_data['added_weight']
-            layer.weight.data -= added_weight
-        self.unet.unload_lora(lora_name, clean_cache)
-        del self.loaded_lora[lora_name]
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def clean_lora_cache(self):
-        self.unet.clean_lora_cache()
-
-    def get_loaded_lora(self):
-        return self.unet.get_loaded_lora()
-
     def _encode_prompt(
         self,
         prompt,
@@ -253,6 +193,23 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

     def check_inputs(
         self,
         prompt,
@@ -342,21 +299,8 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, dim=0)

-        image_batch_size = image.shape[0]
-
-        if image_batch_size == 1:
-            repeat_by = batch_size
-        else:
-            # image batch size is the same as prompt batch size
-            repeat_by = num_images_per_prompt
-
-        image = image.repeat_interleave(repeat_by, dim=0)
-
         image = image.to(device=device, dtype=dtype)

-        if do_classifier_free_guidance and not guess_mode:
-            image = torch.cat([image] * 2)
-
         return image

     def prepare_extra_step_kwargs(self, generator, eta):
@@ -376,6 +320,18 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

     @torch.no_grad()
     def __call__(
         self,
@@ -527,7 +483,7 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
             latent_model_input = latent_model_input.permute(0, 2, 3, 1).contiguous()

-
             noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, controlnet_names, control_images, control_scales, guess_mode)

             noise_pred = noise_pred.permute(0, 3, 1, 2)
@@ -540,7 +496,9 @@ class LyraSdControlnetTxt2ImgPipeline(TextualInversionLoaderMixin):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

-        image = self.decode_latents(latents)
         image = numpy_to_pil(image)

         return image

 import torch
 from typing import Any, Callable, Dict, List, Optional, Union
 from diffusers.loaders import TextualInversionLoaderMixin
+from diffusers.utils.torch_utils import logging, randn_tensor
 from diffusers.utils import PIL_INTERPOLATION
 import os
 import numpy as np
 from .lora_util import add_text_lora_layer
 import inspect

 import time
+from .lyrasd_pipeline_base import LyraSDXLPipelineBase

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

     return pil_images


+class LyraSdControlnetTxt2ImgPipeline(LyraSDXLPipelineBase):
+    def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
+        super().__init__(device, dtype, vae_scale_factor=vae_scale_factor,
+                         vae_scaling_factor=vae_scaling_factor)
+
     def _encode_prompt(
         self,
         prompt,

         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

+    def lyra_decode_latents(self, latents):
+        print("lyra_decode_latents")
+        # np.save("", latents.)
+        # np.save(f"/workspace/vae_model/latent.npy", latents.detach().cpu().numpy())
+        latents = 1 / self.vae_scaling_factor * latents
+        latents = latents.permute(0, 2, 3, 1).contiguous()
+        image = self.vae.vae_decode(latents)
+
+        # print(image)
+        # GPUtil.showUtilization(all=True)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().float().numpy()
+
+        return image
+
     def check_inputs(
         self,
         prompt,

         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, dim=0)

         image = image.to(device=device, dtype=dtype)

         return image

     def prepare_extra_step_kwargs(self, generator, eta):

         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

+    def lyra_decode_latents(self, latents):
+        print("lyra_decode_latents")
+        latents = 1 / self.vae_scaling_factor * latents
+        image = self.vae.decode(latents)
+        image = image.permute(0, 2, 3, 1)
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().float().numpy()
+
+        return image
+
+
     @torch.no_grad()
     def __call__(
         self,

             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
             latent_model_input = latent_model_input.permute(0, 2, 3, 1).contiguous()

+            control_images[0]
             noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, controlnet_names, control_images, control_scales, guess_mode)

             noise_pred = noise_pred.permute(0, 3, 1, 2)

             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

+        # image = self.decode_latents(latents)
+        image = self.lyra_decode_latents(latents)
+
         image = numpy_to_pil(image)

         return image
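The new lyra_decode_latents helper above replaces decode_latents: it rescales the latents by 1 / vae_scaling_factor, runs the lyra VAE decoder, and then maps the decoder output from [-1, 1] to pixel values. A self-contained sketch of just that post-processing step, using a made-up decoder output in place of the fused VAE:

import torch

# Hypothetical decoder output in [-1, 1], already channels-last: (batch, H, W, 3).
decoded = torch.tanh(torch.randn(1, 512, 512, 3))

image = (decoded / 2 + 0.5).clamp(0, 1)         # [-1, 1] -> [0, 1]
image = image.cpu().float().numpy()
pixels = (image * 255).round().astype("uint8")  # what numpy_to_pil does next
print(pixels.shape, pixels.dtype)               # (1, 512, 512, 3) uint8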
lyrasd_model/lyrasd_img2img_pipeline.py
CHANGED
@@ -8,13 +8,12 @@ import numpy as np
 import PIL
 import torch
 from diffusers.loaders import TextualInversionLoaderMixin
-from diffusers.
-from diffusers.
-from diffusers.utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
 from PIL import Image
-
-from .
-

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -28,7 +27,8 @@ def numpy_to_pil(images):
     images = (images * 255).round().astype("uint8")
     if images.shape[-1] == 1:
         # special case for grayscale (single channel) images
-        pil_images = [Image.fromarray(image.squeeze(), mode="L")
     else:
         pil_images = [Image.fromarray(image) for image in images]

@@ -50,7 +50,8 @@ def preprocess(image):
     w, h = image[0].size
     w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

-    image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
     image = np.concatenate(image, axis=0)
     image = np.array(image).astype(np.float32) / 255.0
     image = image.transpose(0, 3, 1, 2)
@@ -61,60 +62,13 @@ def preprocess(image):
     return image


-class LyraSDImg2ImgPipeline(
-    def __init__(self,
-
-
-
-        torch.classes.load_library(lib_so_path)
-
-        self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(dtype).to(device)
-        self.tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
-        self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(dtype).to(device)
-        unet_path = os.path.join(model_path, "unet_bins/")
-
-        self.unet_in_channels = 4
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.vae.enable_tiling()
-        self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
-            3,      # max num of controlnets
-            "fp16"  # inference dtype (can only use fp16 for now)
-        )
-
-        self.reload_unet_model(unet_path, model_dtype)
-
-        self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
-
-    def reload_unet_model(self, unet_path, unet_file_format='fp32'):
-        if len(unet_path) > 0 and unet_path[-1] != "/":
-            unet_path = unet_path + "/"
-        return self.unet.reload_unet_model(unet_path, unet_file_format)
-
-    def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
-        if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
-            lora_model_path = lora_model_path + "/"
-        lora = add_text_lora_layer(self.text_encoder, lora_model_path, lora_strength, lora_file_format)
-        self.loaded_lora[lora_name] = lora
-        self.unet.load_lora(lora_model_path, lora_name, lora_strength, lora_file_format)
-
-    def unload_lora(self, lora_name, clean_cache=False):
-        for layer_data in self.loaded_lora[lora_name]:
-            layer = layer_data['layer']
-            added_weight = layer_data['added_weight']
-            layer.weight.data -= added_weight
-        self.unet.unload_lora(lora_name, clean_cache)
-        del self.loaded_lora[lora_name]
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    def clean_lora_cache(self):
-        self.unet.clean_lora_cache()
-
-    def get_loaded_lora(self):
-        return self.unet.get_loaded_lora()
-

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
     def _encode_prompt(
         self,
         prompt,
@@ -170,7 +124,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
                 return_tensors="pt",
             )
             text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(

             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                 text_input_ids, untruncated_ids
@@ -201,12 +156,14 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
         else:
             prompt_embeds_dtype = prompt_embeds.dtype

-        prompt_embeds = prompt_embeds.to(

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -231,7 +188,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):

             # textual inversion: procecss multi-vector tokens if necessary
             if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokens = self.maybe_convert_prompt(

             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
@@ -257,10 +215,13 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(

-            negative_prompt_embeds = negative_prompt_embeds.repeat(
-

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -286,13 +247,15 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]

-        accepts_eta = "eta" in set(inspect.signature(
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
@@ -301,10 +264,12 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
         self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
     ):
         if strength < 0 or strength > 1:
-            raise ValueError(

         if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(
         ):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
@@ -321,7 +286,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(

         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -339,7 +305,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):

     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
-        init_timestep = min(

         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
@@ -354,6 +321,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):

         image = image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt

         if image.shape[1] == 4:
@@ -368,13 +337,13 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):

         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i: i + 1]).
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).

-        init_latents = self.vae.

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -384,9 +353,11 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
                 " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                 " your script to pass as many initial images as text prompts to suppress this warning."
             )
-            deprecate("len(prompt) != len(image)", "1.0.0",
             additional_image_per_prompt = batch_size // init_latents.shape[0]
-            init_latents = torch.cat(
         elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
             raise ValueError(
                 f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
@@ -395,7 +366,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
             init_latents = torch.cat([init_latents], dim=0)

         shape = init_latents.shape
-        noise = randn_tensor(shape, generator=generator,

         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
@@ -403,6 +375,17 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):

         return latents

     @torch.no_grad()
     def __call__(
         self,
@@ -421,10 +404,12 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
-        generator: Optional[Union[torch.Generator,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        callback: Optional[Callable[[
         callback_steps: int = 1,
     ):
         r"""
@@ -482,7 +467,8 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
                 "not-safe-for-work" (nsfw) content.
         """
         # 1. Check inputs. Raise error if not correct
-        self.check_inputs(prompt, strength, callback_steps,

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -510,12 +496,14 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
         )

         # 4. Preprocess image
-        image = preprocess(image)

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, num_inference_steps = self.get_timesteps(
-

         # 6. Prepare latent variables
         latents = self.prepare_latents(
@@ -526,29 +514,36 @@ class LyraSDImg2ImgPipeline(TextualInversionLoaderMixin):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
-        num_warmup_steps = len(timesteps) -

         for i, t in enumerate(timesteps):
             # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat(
-
-            latent_model_input =

             # predict the noise residual
-            # 后边
-            noise_pred = self.unet.forward(
-
             noise_pred = noise_pred.permute(0, 3, 1, 2)

             # perform guidance
             if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale *

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(

-        image = self.decode_latents(latents)
         image = numpy_to_pil(image)

         return image

 import PIL
 import torch
 from diffusers.loaders import TextualInversionLoaderMixin
+from diffusers.utils import PIL_INTERPOLATION, deprecate
+from diffusers.utils.torch_utils import logging, randn_tensor
 from PIL import Image
+
+from .lyrasd_pipeline_base import LyraSDXLPipelineBase
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

     images = (images * 255).round().astype("uint8")
     if images.shape[-1] == 1:
         # special case for grayscale (single channel) images
+        pil_images = [Image.fromarray(image.squeeze(), mode="L")
+                      for image in images]
     else:
         pil_images = [Image.fromarray(image) for image in images]

     w, h = image[0].size
     w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

+    image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[
+        None, :] for i in image]
     image = np.concatenate(image, axis=0)
     image = np.array(image).astype(np.float32) / 255.0
     image = image.transpose(0, 3, 1, 2)

     return image


+class LyraSDImg2ImgPipeline(LyraSDXLPipelineBase):
+    def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
+        super().__init__(device, dtype, vae_scale_factor=vae_scale_factor,
+                         vae_scaling_factor=vae_scaling_factor)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+
     def _encode_prompt(
         self,
         prompt,

                 return_tensors="pt",
             )
             text_input_ids = text_inputs.input_ids
+            untruncated_ids = self.tokenizer(
+                prompt, padding="longest", return_tensors="pt").input_ids

             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                 text_input_ids, untruncated_ids

         else:
             prompt_embeds_dtype = prompt_embeds.dtype

+        prompt_embeds = prompt_embeds.to(
+            dtype=prompt_embeds_dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(
+            bs_embed * num_images_per_prompt, seq_len, -1)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:

             # textual inversion: procecss multi-vector tokens if necessary
             if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(
+                    uncond_tokens, self.tokenizer)

             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(

             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

+            negative_prompt_embeds = negative_prompt_embeds.to(
+                dtype=prompt_embeds_dtype, device=device)

+            negative_prompt_embeds = negative_prompt_embeds.repeat(
+                1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1)

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch

         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]

+        accepts_eta = "eta" in set(inspect.signature(
+            self.scheduler.step).parameters.keys())
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

         # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(
+            inspect.signature(self.scheduler.step).parameters.keys())
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

         self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
     ):
         if strength < 0 or strength > 1:
+            raise ValueError(
+                f"The value of strength should in [0.0, 1.0] but is {strength}")

         if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(
+                callback_steps, int) or callback_steps <= 0)
         ):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"

                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(

     def get_timesteps(self, num_inference_steps, strength, device):
         # get the original timestep using init_timestep
+        init_timestep = min(
+            int(num_inference_steps * strength), num_inference_steps)

         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]

         image = image.to(device=device, dtype=dtype)

+        print(image.shape)
+
         batch_size = batch_size * num_images_per_prompt

         if image.shape[1] == 4:

         elif isinstance(generator, list):
             init_latents = [
+                self.vae.encode(image[i: i + 1]).sample(generator[i]) for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
+            init_latents = self.vae.encode(image).sample(generator)

+        init_latents = self.vae.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

                 " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                 " your script to pass as many initial images as text prompts to suppress this warning."
             )
+            deprecate("len(prompt) != len(image)", "1.0.0",
+                      deprecation_message, standard_warn=False)
             additional_image_per_prompt = batch_size // init_latents.shape[0]
+            init_latents = torch.cat(
+                [init_latents] * additional_image_per_prompt, dim=0)
         elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
             raise ValueError(
                 f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."

             init_latents = torch.cat([init_latents], dim=0)

         shape = init_latents.shape
+        noise = randn_tensor(shape, generator=generator,
+                             device=device, dtype=dtype)

         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

         return latents

+    def lyra_decode_latents(self, latents):
+        print("lyra_decode_latents")
+        latents = 1 / self.vae_scaling_factor * latents
+        image = self.vae.decode(latents)
+        image = image.permute(0, 2, 3, 1)
+        image = (image / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+        image = image.cpu().float().numpy()
+
+        return image
+
     @torch.no_grad()
     def __call__(
         self,

         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: Optional[float] = 0.0,
+        generator: Optional[Union[torch.Generator,
+                                  List[torch.Generator]]] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        callback: Optional[Callable[[
+            int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
     ):
         r"""

                 "not-safe-for-work" (nsfw) content.
         """
         # 1. Check inputs. Raise error if not correct
+        self.check_inputs(prompt, strength, callback_steps,
+                          negative_prompt, prompt_embeds, negative_prompt_embeds)

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):

         )

         # 4. Preprocess image
+        image = self.image_processor.preprocess(image)

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps, strength, device)
+        latent_timestep = timesteps[:1].repeat(
+            batch_size * num_images_per_prompt)

         # 6. Prepare latent variables
         latents = self.prepare_latents(

         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - \
+            num_inference_steps * self.scheduler.order

         for i, t in enumerate(timesteps):
             # expand the latents if we are doing classifier free guidance
+            latent_model_input = torch.cat(
+                [latents] * 2) if do_classifier_free_guidance else latents
+            latent_model_input = self.scheduler.scale_model_input(
+                latent_model_input, t)
+            latent_model_input = latent_model_input.permute(
+                0, 2, 3, 1).contiguous()

             # predict the noise residual
+            # the trailing None values are ControlNet parameters; None is passed as a placeholder for now
+            noise_pred = self.unet.forward(
+                latent_model_input, prompt_embeds, t)
             noise_pred = noise_pred.permute(0, 3, 1, 2)

             # perform guidance
             if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * \
+                    (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(
+                noise_pred, t, latents, **extra_step_kwargs).prev_sample

+        # image = self.decode_latents(latents)
+        image = self.lyra_decode_latents(latents)
         image = numpy_to_pil(image)

         return image
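In the updated img2img pipeline, get_timesteps keeps only the tail of the schedule according to strength, so a low strength runs fewer denoising steps over a lightly-noised init image. A standalone sketch of that slicing with illustrative numbers (a real run would slice scheduler.timesteps after set_timesteps):

num_inference_steps = 30
strength = 0.6
scheduler_order = 1  # Euler-family schedulers have order 1

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 18
t_start = max(num_inference_steps - init_timestep, 0)                          # 12

timesteps = list(range(num_inference_steps - 1, -1, -1))  # stand-in for scheduler.timesteps
timesteps = timesteps[t_start * scheduler_order:]
print(len(timesteps))  # 18 denoising steps actually executed for strength=0.6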
lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm80.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0689ed5d3b55f5033a8869d5f23ce900793aa0ab7fdc4a3e3c0a0f3a243c83da
-size 65441456
lyrasd_model/lyrasd_lib/libth_lyrasd_cu11_sm86.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b8e27e715fa3a17ce25bf23b772e0dd355d0780c1bd93cfeeb12ef45b0ba2444
-size 65389176
lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
 version https://git-lfs.github.com/spec/v1
+oid sha256:8600f5414d283ebf64cb3974ef520858747cbb1a6d59dd46a3dcd9427758613b
+size 97823240
lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm86.so
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
 version https://git-lfs.github.com/spec/v1
+oid sha256:8e5aefbb32667eeacb7fa60283656b4bb2ebb7dcd54276f9d101c856ed64e340
+size 97823240
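The .so entries above are Git LFS pointer files (version / oid / size), so the diff only swaps the sha256 and size of the tracked binaries. A small sketch for checking that a locally fetched library matches its pointer; the path below is only an example:

import hashlib
import os

so_path = "lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"  # example path
expected_oid = "8600f5414d283ebf64cb3974ef520858747cbb1a6d59dd46a3dcd9427758613b"
expected_size = 97823240

sha = hashlib.sha256()
with open(so_path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

assert os.path.getsize(so_path) == expected_size, "size mismatch: was `git lfs pull` run?"
assert sha.hexdigest() == expected_oid, "sha256 mismatch: corrupted or wrong LFS object"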
lyrasd_model/lyrasd_pipeline_base.py
ADDED
@@ -0,0 +1,214 @@
import inspect
import os
import time
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

import gc
import torch
import numpy as np
from glob import glob

from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL
from diffusers.schedulers import (DPMSolverMultistepScheduler,
                                  EulerAncestralDiscreteScheduler,
                                  EulerDiscreteScheduler,
                                  KarrasDiffusionSchedulers)
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from .lyrasd_vae_model import LyraSdVaeModel
from .module.lyrasd_ip_adapter import LyraIPAdapter
from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
from safetensors.torch import load_file


class LyraSDXLPipelineBase(TextualInversionLoaderMixin):
    def __init__(self, device=torch.device("cuda"), dtype=torch.float16, num_channels_unet=4, num_channels_latents=4, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
        self.device = device
        self.dtype = dtype

        self.num_channels_unet = num_channels_unet
        self.num_channels_latents = num_channels_latents
        self.vae_scale_factor = vae_scale_factor
        self.vae_scaling_factor = vae_scaling_factor

        self.unet_cache = {}
        self.unet_in_channels = 4

        self.controlnet_cache = {}

        self.loaded_lora = {}
        self.loaded_lora_strength = {}

        self.scheduler = None

        self.init_pipe()

    def init_pipe(self):
        self.vae = LyraSdVaeModel(
            scale_factor=self.vae_scale_factor, scaling_factor=self.vae_scaling_factor)

        self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
            3,
            "fp16",
            self.num_channels_unet,
            self.num_channels_latents
        )

        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor)

        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )

        self.feature_extractor = CLIPImageProcessor()

    def reload_pipe(self, model_path):
        self.tokenizer = CLIPTokenizer.from_pretrained(
            model_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_path, subfolder="text_encoder").to(self.dtype).to(self.device)

        self.reload_unet_model_v2(model_path)
        self.reload_vae_model_v2(model_path)

        if not self.scheduler:
            self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
                model_path, subfolder="scheduler")

    @property
    def _execution_device(self):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def reload_unet_model(self, unet_path, unet_file_format='fp32'):
        if len(unet_path) > 0 and unet_path[-1] != "/":
            unet_path = unet_path + "/"
        self.unet.reload_unet_model(unet_path, unet_file_format)
        self.load_embedding_weight(
            self.add_embedding, f"{unet_path}add_embedding*", unet_file_format=unet_file_format)

    def reload_vae_model(self, vae_path, vae_file_format='fp32'):
        if len(vae_path) > 0 and vae_path[-1] != "/":
            vae_path = vae_path + "/"
        return self.vae.reload_vae_model(vae_path, vae_file_format)

    def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
        if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
            lora_model_path = lora_model_path + "/"
        lora = add_xltext_lora_layer(
            self.text_encoder, self.text_encoder_2, lora_model_path, lora_strength, lora_file_format)

        self.loaded_lora[lora_name] = lora
        self.unet.load_lora(lora_model_path, lora_name,
                            lora_strength, lora_file_format)

    def unload_lora(self, lora_name, clean_cache=False):
        for layer_data in self.loaded_lora[lora_name]:
            layer = layer_data['layer']
            added_weight = layer_data['added_weight']
            layer.weight.data -= added_weight
        self.unet.unload_lora(lora_name, clean_cache)
        del self.loaded_lora[lora_name]
        gc.collect()
        torch.cuda.empty_cache()

    def load_lora_v2(self, lora_model_path, lora_name, lora_strength):
        if lora_name in self.loaded_lora:
            state_dict = self.loaded_lora[lora_name]
        else:
            state_dict = load_state_dict(lora_model_path)
            self.loaded_lora[lora_name] = state_dict
        self.loaded_lora_strength[lora_name] = lora_strength
        add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
                              None, lora_strength)

    def unload_lora_v2(self, lora_name, clean_cache=False):
        state_dict = self.loaded_lora[lora_name]
        lora_strength = self.loaded_lora_strength[lora_name]
        add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
                              None, -1.0 * lora_strength)
        del self.loaded_lora_strength[lora_name]

        if clean_cache:
            del self.loaded_lora[lora_name]
            gc.collect()
            torch.cuda.empty_cache()

    def clean_lora_cache(self):
        self.unet.clean_lora_cache()

    def get_loaded_lora(self):
        return self.unet.get_loaded_lora()

    def load_ip_adapter(self, dir_ip_adapter, ip_plus, image_encoder_path, num_ip_tokens, ip_projection_dim, dir_face_in=None, num_fp_tokens=1, fp_projection_dim=None, sdxl=True):
        self.ip_adapter_helper = LyraIPAdapter(self, sdxl, "cuda", dir_ip_adapter, ip_plus, image_encoder_path,
                                               num_ip_tokens, ip_projection_dim, dir_face_in, num_fp_tokens, fp_projection_dim)

    def reload_unet_model_v2(self, model_path):
        checkpoint_file = os.path.join(
            model_path, "unet/diffusion_pytorch_model.bin")
        if not os.path.exists(checkpoint_file):
            checkpoint_file = os.path.join(
                model_path, "unet/diffusion_pytorch_model.safetensors")
        if checkpoint_file in self.unet_cache:
            state_dict = self.unet_cache[checkpoint_file]
        else:
            if "safetensors" in checkpoint_file:
                state_dict = load_file(checkpoint_file)
            else:
                state_dict = torch.load(checkpoint_file, map_location="cpu")

            for key in state_dict:
                if len(state_dict[key].shape) == 4:
                    # converted_unet_checkpoint[key] = converted_unet_checkpoint[key].to(torch.float16).to("cuda").permute(0,2,3,1).contiguous().cpu()
                    state_dict[key] = state_dict[key].to(
                        torch.float16).permute(0, 2, 3, 1).contiguous()
                state_dict[key] = state_dict[key].to(torch.float16)
            self.unet_cache[checkpoint_file] = state_dict

        self.unet.reload_unet_model_from_cache(state_dict, "cpu")

    def reload_vae_model_v2(self, model_path):
        self.vae.reload_vae_model_v2(model_path)

    def load_controlnet_model(self, model_name, controlnet_path, model_dtype="fp32"):
        if len(controlnet_path) > 0 and controlnet_path[-1] != "/":
            controlnet_path = controlnet_path + "/"
        self.unet.load_controlnet_model(model_name, controlnet_path, model_dtype)

    def unload_controlnet_model(self, model_name):
        self.unet.unload_controlnet_model(model_name, True)

    def get_loaded_controlnet(self):
        return self.unet.get_loaded_controlnet()

    def load_controlnet_model_v2(self, model_name, controlnet_path):
        checkpoint_file = os.path.join(controlnet_path, "diffusion_pytorch_model.bin")
        if not os.path.exists(checkpoint_file):
            checkpoint_file = os.path.join(controlnet_path, "diffusion_pytorch_model.safetensors")
        if checkpoint_file in self.controlnet_cache:
            state_dict = self.controlnet_cache[checkpoint_file]
        else:
            if "safetensors" in checkpoint_file:
                state_dict = load_file(checkpoint_file)
            else:
                state_dict = torch.load(checkpoint_file, map_location="cpu")

            for key in state_dict:
                if len(state_dict[key].shape) == 4:
                    # converted_unet_checkpoint[key] = converted_unet_checkpoint[key].to(torch.float16).to("cuda").permute(0,2,3,1).contiguous().cpu()
                    state_dict[key] = state_dict[key].to(torch.float16).permute(0, 2, 3, 1).contiguous()
                state_dict[key] = state_dict[key].to(torch.float16)
            self.controlnet_cache[checkpoint_file] = state_dict

        self.unet.load_controlnet_model_from_state_dict(model_name, state_dict, "cpu")
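Both reload_unet_model_v2 and load_controlnet_model_v2 above share the same weight preprocessing: every 4-D convolution weight is cast to fp16 and permuted to channels-last before the state dict is handed to the fused op. A standalone sketch of that conversion on a made-up state dict (the tensor names and shapes are illustrative only):

import torch

# Illustrative entry: a conv weight in the usual (out_ch, in_ch, kH, kW) layout.
state_dict = {"conv_in.weight": torch.randn(320, 4, 3, 3), "conv_in.bias": torch.randn(320)}

for key in state_dict:
    if len(state_dict[key].shape) == 4:
        # 4-D conv weights become channels-last (out_ch, kH, kW, in_ch) for the fused op.
        state_dict[key] = state_dict[key].to(torch.float16).permute(0, 2, 3, 1).contiguous()
    state_dict[key] = state_dict[key].to(torch.float16)

print(state_dict["conv_in.weight"].shape)  # torch.Size([320, 3, 3, 4])
print(state_dict["conv_in.bias"].dtype)    # torch.float16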
lyrasd_model/lyrasd_txt2img_inpaint_pipeline.py
ADDED
@@ -0,0 +1,826 @@
import inspect
import os
import sys
import time
from typing import Any, Callable, Dict, List, Optional, Union
import GPUtil
import torch
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils.torch_utils import logging, randn_tensor
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
import gc
import numpy as np

from .lyrasd_vae_model import LyraSdVaeModel

from diffusers.models.embeddings import ImageProjection
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
)

from .lyrasd_pipeline_base import LyraSDXLPipelineBase

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + \
        (1 - guidance_rescale) * noise_cfg
    return noise_cfg

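rescale_noise_cfg implements the guidance-rescale trick cited in its docstring: the guided prediction is rescaled so its per-sample standard deviation matches that of the text-conditional prediction, then blended back in by guidance_rescale. A small self-contained numeric check of that property on toy tensors (guidance_rescale = 1.0 for clarity):

import torch

torch.manual_seed(0)
noise_pred_text = torch.randn(2, 4, 8, 8)
noise_cfg = 7.0 * torch.randn(2, 4, 8, 8)  # an over-amplified guided prediction

std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
rescaled = noise_cfg * (std_text / std_cfg)        # same direction, text-like scale
blended = 1.0 * rescaled + (1 - 1.0) * noise_cfg   # guidance_rescale = 1.0

print(blended.std(dim=[1, 2, 3]))         # matches the per-sample std of noise_pred_text
print(noise_pred_text.std(dim=[1, 2, 3]))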
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L")
                      for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        print("set(inspect.signature(scheduler.set_timesteps).parameters.keys())", set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()))
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

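retrieve_timesteps mostly just forwards to scheduler.set_timesteps and reads back scheduler.timesteps; the custom-timesteps branch is only taken when the scheduler supports it. A sketch of the default branch using a default-configured scheduler (the real pipeline loads its scheduler config from the model directory):

import torch
from diffusers.schedulers import EulerAncestralDiscreteScheduler

# Default-configured scheduler purely for illustration.
scheduler = EulerAncestralDiscreteScheduler()

# Mirrors the `timesteps is None` branch of retrieve_timesteps above.
scheduler.set_timesteps(25, device=torch.device("cpu"))
timesteps = scheduler.timesteps

print(len(timesteps), timesteps[:3])  # 25 steps, descending from ~999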
110 |
+
def retrieve_latents(
|
111 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
112 |
+
):
|
113 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
114 |
+
return encoder_output.latent_dist.sample(generator)
|
115 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
116 |
+
return encoder_output.latent_dist.mode()
|
117 |
+
elif hasattr(encoder_output, "latents"):
|
118 |
+
return encoder_output.latents
|
119 |
+
else:
|
120 |
+
raise AttributeError(
|
121 |
+
"Could not access latents of provided encoder_output")
|
122 |
+
|
123 |
+
|
124 |
+
class LyraSdTxt2ImgInpaintPipeline(LyraSDXLPipelineBase):
|
125 |
+
def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215, num_channels_unet=9, num_channels_latents=4) -> None:
|
126 |
+
super().__init__(device, dtype, num_channels_unet=num_channels_unet, num_channels_latents=num_channels_latents,
|
127 |
+
vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)
|
128 |
+
|
129 |
+
def _encode_prompt(
|
130 |
+
self,
|
131 |
+
prompt,
|
132 |
+
device,
|
133 |
+
num_images_per_prompt,
|
134 |
+
do_classifier_free_guidance,
|
135 |
+
negative_prompt=None,
|
136 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
137 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
138 |
+
):
|
139 |
+
r"""
|
140 |
+
Encodes the prompt into text encoder hidden states.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
prompt (`str` or `List[str]`, *optional*):
|
144 |
+
prompt to be encoded
|
145 |
+
device: (`torch.device`):
|
146 |
+
torch device
|
147 |
+
num_images_per_prompt (`int`):
|
148 |
+
number of images that should be generated per prompt
|
149 |
+
do_classifier_free_guidance (`bool`):
|
150 |
+
whether to use classifier free guidance or not
|
151 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
152 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
153 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
154 |
+
less than `1`).
|
155 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
156 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
157 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
158 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
159 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
160 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
161 |
+
argument.
|
162 |
+
"""
|
163 |
+
if prompt is not None and isinstance(prompt, str):
|
164 |
+
batch_size = 1
|
165 |
+
elif prompt is not None and isinstance(prompt, list):
|
166 |
+
batch_size = len(prompt)
|
167 |
+
else:
|
168 |
+
batch_size = prompt_embeds.shape[0]
|
169 |
+
|
170 |
+
if prompt_embeds is None:
|
171 |
+
# textual inversion: procecss multi-vector tokens if necessary
|
172 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
173 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
174 |
+
|
175 |
+
text_inputs = self.tokenizer(
|
176 |
+
prompt,
|
177 |
+
padding="max_length",
|
178 |
+
max_length=self.tokenizer.model_max_length,
|
179 |
+
truncation=True,
|
180 |
+
return_tensors="pt",
|
181 |
+
)
|
182 |
+
text_input_ids = text_inputs.input_ids
|
183 |
+
untruncated_ids = self.tokenizer(
|
184 |
+
prompt, padding="longest", return_tensors="pt").input_ids
|
185 |
+
|
186 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
187 |
+
text_input_ids, untruncated_ids
|
188 |
+
):
|
189 |
+
removed_text = self.tokenizer.batch_decode(
|
190 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
|
191 |
+
)
|
192 |
+
logger.warning(
|
193 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
194 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
195 |
+
)
|
196 |
+
|
197 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
198 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
199 |
+
else:
|
200 |
+
attention_mask = None
|
201 |
+
|
202 |
+
prompt_embeds = self.text_encoder(
|
203 |
+
text_input_ids.to(device),
|
204 |
+
attention_mask=attention_mask,
|
205 |
+
)
|
206 |
+
prompt_embeds = prompt_embeds[0]
|
207 |
+
|
208 |
+
prompt_embeds = prompt_embeds.to(
|
209 |
+
dtype=self.text_encoder.dtype, device=device)
|
210 |
+
|
211 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
212 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
213 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
214 |
+
prompt_embeds = prompt_embeds.view(
|
215 |
+
bs_embed * num_images_per_prompt, seq_len, -1)
|
216 |
+
|
217 |
+
# get unconditional embeddings for classifier free guidance
|
218 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
219 |
+
uncond_tokens: List[str]
|
220 |
+
if negative_prompt is None:
|
221 |
+
uncond_tokens = [""] * batch_size
|
222 |
+
elif type(prompt) is not type(negative_prompt):
|
223 |
+
raise TypeError(
|
224 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
225 |
+
f" {type(prompt)}."
|
226 |
+
)
|
227 |
+
elif isinstance(negative_prompt, str):
|
228 |
+
uncond_tokens = [negative_prompt]
|
229 |
+
elif batch_size != len(negative_prompt):
|
230 |
+
raise ValueError(
|
231 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
232 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
233 |
+
" the batch size of `prompt`."
|
234 |
+
)
|
235 |
+
else:
|
236 |
+
uncond_tokens = negative_prompt
|
237 |
+
|
238 |
+
# textual inversion: process multi-vector tokens if necessary
|
239 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
240 |
+
uncond_tokens = self.maybe_convert_prompt(
|
241 |
+
uncond_tokens, self.tokenizer)
|
242 |
+
|
243 |
+
max_length = prompt_embeds.shape[1]
|
244 |
+
uncond_input = self.tokenizer(
|
245 |
+
uncond_tokens,
|
246 |
+
padding="max_length",
|
247 |
+
max_length=max_length,
|
248 |
+
truncation=True,
|
249 |
+
return_tensors="pt",
|
250 |
+
)
|
251 |
+
|
252 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
253 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
254 |
+
else:
|
255 |
+
attention_mask = None
|
256 |
+
|
257 |
+
negative_prompt_embeds = self.text_encoder(
|
258 |
+
uncond_input.input_ids.to(device),
|
259 |
+
attention_mask=attention_mask,
|
260 |
+
)
|
261 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
262 |
+
|
263 |
+
if do_classifier_free_guidance:
|
264 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
265 |
+
seq_len = negative_prompt_embeds.shape[1]
|
266 |
+
|
267 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
268 |
+
dtype=self.text_encoder.dtype, device=device)
|
269 |
+
|
270 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
271 |
+
1, num_images_per_prompt, 1)
|
272 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
273 |
+
batch_size * num_images_per_prompt, seq_len, -1)
|
274 |
+
|
275 |
+
# For classifier free guidance, we need to do two forward passes.
|
276 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
277 |
+
# to avoid doing two forward passes
|
278 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
279 |
+
|
280 |
+
return prompt_embeds
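The repeat-then-view trick above is how the pipeline duplicates text embeddings per generated image in an mps-friendly way, and under classifier-free guidance the negative embeddings are concatenated in front of the positive ones. A minimal, self-contained sketch of that bookkeeping, using dummy tensors with CLIP-like shapes (all names here are illustrative only):

```python
import torch

bs, seq, dim = 2, 77, 768            # batch of 2 prompts, CLIP-style sequence length
num_images_per_prompt = 3

prompt_embeds = torch.randn(bs, seq, dim)
negative_embeds = torch.randn(bs, seq, dim)

# duplicate each prompt's embedding num_images_per_prompt times (rows stay grouped per prompt)
dup = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(bs * num_images_per_prompt, seq, dim)
neg_dup = negative_embeds.repeat(1, num_images_per_prompt, 1).view(bs * num_images_per_prompt, seq, dim)
assert torch.equal(dup[0], dup[2])              # first prompt repeated 3 times
assert torch.equal(dup[0], prompt_embeds[0])

# layout returned under CFG: negatives first, then positives
cfg_embeds = torch.cat([neg_dup, dup])          # shape (2 * bs * num_images_per_prompt, seq, dim)
```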
|
281 |
+
|
282 |
+
def load_ip_adapter(self,
|
283 |
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
284 |
+
subfolder: str,
|
285 |
+
weight_name: str,
|
286 |
+
**kwargs
|
287 |
+
):
|
288 |
+
# if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
289 |
+
self.feature_extractor = CLIPImageProcessor()
|
290 |
+
|
291 |
+
# if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
292 |
+
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
293 |
+
pretrained_model_name_or_path_or_dict,
|
294 |
+
subfolder=os.path.join(subfolder, "image_encoder"),
|
295 |
+
).to(self.device, dtype=self.dtype)
|
296 |
+
# else:
|
297 |
+
# print("kio: already has image_encoder", hasattr(self, "image_encoder"), getattr(self, "feature_extractor", None) is None)
|
298 |
+
|
299 |
+
# kiotodo: init ImageProjection
|
300 |
+
model_path = os.path.join(
|
301 |
+
pretrained_model_name_or_path_or_dict, subfolder, weight_name)
|
302 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
303 |
+
|
304 |
+
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
|
305 |
+
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
|
306 |
+
self.encoder_hid_proj = ImageProjection(
|
307 |
+
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
|
308 |
+
)
|
309 |
+
|
310 |
+
image_proj_state_dict = {}
|
311 |
+
image_proj_state_dict.update(
|
312 |
+
{
|
313 |
+
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
|
314 |
+
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
|
315 |
+
"norm.weight": state_dict["image_proj"]["norm.weight"],
|
316 |
+
"norm.bias": state_dict["image_proj"]["norm.bias"],
|
317 |
+
}
|
318 |
+
)
|
319 |
+
|
320 |
+
self.encoder_hid_proj.load_state_dict(image_proj_state_dict)
|
321 |
+
self.encoder_hid_proj.to(dtype=self.dtype, device=self.device)
|
322 |
+
|
323 |
+
dir_ipadapter = os.path.join(
|
324 |
+
pretrained_model_name_or_path_or_dict, subfolder, '.'.join(weight_name.split(".")[:-1]))
|
325 |
+
self.unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
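With the projection layers loaded, the adapter is driven entirely through the `ip_adapter_image` argument of `__call__`. A hedged usage sketch follows; the directory layout and weight file name are placeholders, and `pipe`, `init_image` and `mask` stand for objects you have already prepared:

```python
from PIL import Image

# hypothetical paths, adjust to wherever the IP-Adapter weights actually live
pipe.load_ip_adapter("./models/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ref = Image.open("reference.png").convert("RGB")
images = pipe(
    prompt="a cat, masterpiece, best quality",
    image=init_image,            # the image to be inpainted
    mask_image=mask,             # white = repaint, black = keep
    ip_adapter_image=ref,        # conditions the UNet on the reference image
    num_inference_steps=30,
)
```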
|
326 |
+
|
327 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
328 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
329 |
+
if not isinstance(image, torch.Tensor):
|
330 |
+
image = self.feature_extractor(
|
331 |
+
image, return_tensors="pt").pixel_values
|
332 |
+
|
333 |
+
image = image.to(device=device, dtype=dtype)
|
334 |
+
image_embeds = self.image_encoder(image).image_embeds
|
335 |
+
image_embeds = image_embeds.repeat_interleave(
|
336 |
+
num_images_per_prompt, dim=0)
|
337 |
+
|
338 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
339 |
+
return image_embeds, uncond_image_embeds
|
340 |
+
|
341 |
+
def decode_latents(self, latents):
|
342 |
+
latents = 1 / self.vae.scaling_factor * latents
|
343 |
+
image = self.vae.decode(latents).sample
|
344 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
345 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
346 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
347 |
+
return image
|
348 |
+
|
349 |
+
def lyra_decode_latents(self, latents):
    latents = 1 / self.vae.scaling_factor * latents
    image = self.vae.decode(latents)
    image = image.permute(0, 2, 3, 1)
    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    image = image.cpu().float().numpy()
    return image
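`lyra_decode_latents` hands back an NHWC float array in [0, 1]; `numpy_to_pil` (used at the end of `__call__`) is all that is left to get PIL images. A tiny stand-alone illustration of that last hop, with a random array standing in for a real decode:

```python
import numpy as np
from PIL import Image

decoded = np.random.rand(1, 512, 512, 3).astype(np.float32)   # stand-in for lyra_decode_latents output
pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in decoded]
pil_images[0].save("decoded_0.png")
```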
|
365 |
+
|
366 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
367 |
+
# get the original timestep using init_timestep
|
368 |
+
init_timestep = min(
|
369 |
+
int(num_inference_steps * strength), num_inference_steps)
|
370 |
+
|
371 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
372 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
|
373 |
+
|
374 |
+
return timesteps, num_inference_steps - t_start
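The arithmetic above is the usual img2img/inpaint convention: `strength` decides how many of the scheduled steps are actually run, counting from the noisier end. A worked example, assuming 30 scheduled steps:

```python
num_inference_steps = 30
strength = 0.6

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 18
t_start = max(num_inference_steps - init_timestep, 0)                          # 12
# the 12 noisiest timesteps are skipped and only the remaining 18 denoising steps run;
# strength = 1.0 keeps all 30 and starts from pure noise.
```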
|
375 |
+
|
376 |
+
def check_inputs(
|
377 |
+
self,
|
378 |
+
prompt,
|
379 |
+
height,
|
380 |
+
width,
|
381 |
+
negative_prompt=None,
|
382 |
+
prompt_embeds=None,
|
383 |
+
negative_prompt_embeds=None,
|
384 |
+
):
|
385 |
+
if height % 64 != 0 or width % 64 != 0:  # the initial release only supports heights and widths that are multiples of 64
|
386 |
+
raise ValueError(
|
387 |
+
f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
|
388 |
+
|
389 |
+
if prompt is not None and prompt_embeds is not None:
|
390 |
+
raise ValueError(
|
391 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
392 |
+
" only forward one of the two."
|
393 |
+
)
|
394 |
+
elif prompt is None and prompt_embeds is None:
|
395 |
+
raise ValueError(
|
396 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
397 |
+
)
|
398 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
399 |
+
raise ValueError(
|
400 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
401 |
+
|
402 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
403 |
+
raise ValueError(
|
404 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
405 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
406 |
+
)
|
407 |
+
|
408 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
409 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
410 |
+
raise ValueError(
|
411 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
412 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
413 |
+
f" {negative_prompt_embeds.shape}."
|
414 |
+
)
|
415 |
+
|
416 |
+
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
417 |
+
if isinstance(generator, list):
|
418 |
+
image_latents = [
|
419 |
+
retrieve_latents(AutoencoderKLOutput(
|
420 |
+
latent_dist=self.vae.encode(image[i: i + 1])), generator=generator[i])
|
421 |
+
for i in range(image.shape[0])
|
422 |
+
]
|
423 |
+
image_latents = torch.cat(image_latents, dim=0)
|
424 |
+
else:
|
425 |
+
image_latents = retrieve_latents(AutoencoderKLOutput(
|
426 |
+
latent_dist=self.vae.encode(image)), generator=generator)
|
427 |
+
|
428 |
+
image_latents = self.vae_scaling_factor * image_latents
|
429 |
+
|
430 |
+
return image_latents
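The same encode-and-scale step can be shown stand-alone with a small, randomly initialised `AutoencoderKL`, purely to make the shapes and the `scaling_factor` multiplication concrete; a real pipeline would of course use the checkpoint's own VAE:

```python
import torch
from diffusers import AutoencoderKL

# tiny VAE with the SD-style 8x spatial reduction, random weights
vae = AutoencoderKL(
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(32, 64, 128, 128),
    latent_channels=4,
)
image = torch.randn(1, 3, 256, 256)              # preprocessed image in [-1, 1]

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()
latents = vae.config.scaling_factor * latents    # same scaling applied in _encode_vae_image above
print(latents.shape)                             # torch.Size([1, 4, 32, 32])
```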
|
431 |
+
|
432 |
+
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None,
|
433 |
+
image=None, timestep=None, is_strength_max=True, return_noise=False, return_image_latents=False):
|
434 |
+
shape = (batch_size, num_channels_latents, height //
|
435 |
+
self.vae_scale_factor, width // self.vae_scale_factor)
|
436 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
437 |
+
raise ValueError(
|
438 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
439 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
440 |
+
)
|
441 |
+
|
442 |
+
if (image is None or timestep is None) and not is_strength_max:
|
443 |
+
raise ValueError(
|
444 |
+
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
|
445 |
+
"However, either the image or the noise timestep has not been provided."
|
446 |
+
)
|
447 |
+
|
448 |
+
if return_image_latents or (latents is None and not is_strength_max):
|
449 |
+
image = image.to(device=device, dtype=dtype)
|
450 |
+
|
451 |
+
if image.shape[1] == 4:
|
452 |
+
image_latents = image
|
453 |
+
else:
|
454 |
+
image_latents = self._encode_vae_image(
|
455 |
+
image=image, generator=generator)
|
456 |
+
image_latents = image_latents.repeat(
|
457 |
+
batch_size // image_latents.shape[0], 1, 1, 1)
|
458 |
+
|
459 |
+
if latents is None:
|
460 |
+
noise = randn_tensor(shape, generator=generator,
|
461 |
+
device=device, dtype=dtype)
|
462 |
+
# if strength is 1. then initialise the latents to noise, else initial to image + noise
|
463 |
+
latents = noise if is_strength_max else self.scheduler.add_noise(
|
464 |
+
image_latents, noise, timestep)
|
465 |
+
# if pure noise then scale the initial latents by the Scheduler's init sigma
|
466 |
+
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
|
467 |
+
else:
|
468 |
+
noise = latents.to(device)
|
469 |
+
latents = noise * self.scheduler.init_noise_sigma
|
470 |
+
|
471 |
+
outputs = (latents,)
|
472 |
+
|
473 |
+
if return_noise:
|
474 |
+
outputs += (noise,)
|
475 |
+
|
476 |
+
if return_image_latents:
|
477 |
+
outputs += (image_latents,)
|
478 |
+
|
479 |
+
return outputs
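The difference between the two branches above can be seen with a scheduler alone: at `strength == 1.0` the latents are pure noise scaled by `init_noise_sigma`, otherwise the encoded image is noised only up to the first retained timestep. A sketch with dummy latents, using `EulerAncestralDiscreteScheduler` (one of the schedulers imported by these pipelines):

```python
import torch
from diffusers import EulerAncestralDiscreteScheduler

scheduler = EulerAncestralDiscreteScheduler()
scheduler.set_timesteps(30)

image_latents = torch.zeros(1, 4, 64, 64)        # stand-in for the encoded init image
noise = torch.randn(1, 4, 64, 64)

# strength == 1.0: ignore the image and start from scaled pure noise
latents_full = noise * scheduler.init_noise_sigma

# strength < 1.0: noise the image latents up to the first timestep that will actually run
latent_timestep = scheduler.timesteps[:1]
latents_partial = scheduler.add_noise(image_latents, noise, latent_timestep)
```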
|
480 |
+
|
481 |
+
def prepare_mask_latents(
|
482 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
483 |
+
):
|
484 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
485 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
486 |
+
# and half precision
|
487 |
+
mask = torch.nn.functional.interpolate(
|
488 |
+
mask, size=(height // self.vae_scale_factor,
|
489 |
+
width // self.vae_scale_factor)
|
490 |
+
)
|
491 |
+
mask = mask.to(device=device, dtype=dtype)
|
492 |
+
|
493 |
+
masked_image = masked_image.to(device=device, dtype=dtype)
|
494 |
+
|
495 |
+
if masked_image.shape[1] == 4:
|
496 |
+
masked_image_latents = masked_image
|
497 |
+
else:
|
498 |
+
masked_image_latents = self._encode_vae_image(
|
499 |
+
masked_image, generator=generator)
|
500 |
+
|
501 |
+
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
502 |
+
if mask.shape[0] < batch_size:
|
503 |
+
if not batch_size % mask.shape[0] == 0:
|
504 |
+
raise ValueError(
|
505 |
+
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
506 |
+
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
507 |
+
" of masks that you pass is divisible by the total requested batch size."
|
508 |
+
)
|
509 |
+
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
510 |
+
if masked_image_latents.shape[0] < batch_size:
|
511 |
+
if not batch_size % masked_image_latents.shape[0] == 0:
|
512 |
+
raise ValueError(
|
513 |
+
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
514 |
+
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
515 |
+
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
516 |
+
)
|
517 |
+
masked_image_latents = masked_image_latents.repeat(
|
518 |
+
batch_size // masked_image_latents.shape[0], 1, 1, 1)
|
519 |
+
|
520 |
+
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
521 |
+
masked_image_latents = (
|
522 |
+
torch.cat([masked_image_latents] *
|
523 |
+
2) if do_classifier_free_guidance else masked_image_latents
|
524 |
+
)
|
525 |
+
|
526 |
+
# aligning device to prevent device errors when concating it with the latent model input
|
527 |
+
masked_image_latents = masked_image_latents.to(
|
528 |
+
device=device, dtype=dtype)
|
529 |
+
return mask, masked_image_latents
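The essential moves in `prepare_mask_latents` are a downsample of the mask to latent resolution and the masking of the init image before it is VAE-encoded. A self-contained illustration with dummy tensors, assuming 512x512 pixels and the 8x latent scale:

```python
import torch

vae_scale_factor = 8
init_image = torch.randn(1, 3, 512, 512)              # preprocessed, in [-1, 1]
mask = (torch.rand(1, 1, 512, 512) > 0.5).float()     # 1 = repaint, 0 = keep

# downsample the mask to latent resolution, exactly as above
mask_latent = torch.nn.functional.interpolate(
    mask, size=(512 // vae_scale_factor, 512 // vae_scale_factor))

# zero out the region that will be repainted before it is encoded by the VAE
masked_image = init_image * (mask < 0.5)
print(mask_latent.shape, masked_image.shape)          # (1, 1, 64, 64) (1, 3, 512, 512)
```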
|
530 |
+
|
531 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
532 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
533 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
534 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
535 |
+
# and should be between [0, 1]
|
536 |
+
|
537 |
+
accepts_eta = "eta" in set(inspect.signature(
|
538 |
+
self.scheduler.step).parameters.keys())
|
539 |
+
extra_step_kwargs = {}
|
540 |
+
if accepts_eta:
|
541 |
+
extra_step_kwargs["eta"] = eta
|
542 |
+
|
543 |
+
# check if the scheduler accepts generator
|
544 |
+
accepts_generator = "generator" in set(
|
545 |
+
inspect.signature(self.scheduler.step).parameters.keys())
|
546 |
+
if accepts_generator:
|
547 |
+
extra_step_kwargs["generator"] = generator
|
548 |
+
return extra_step_kwargs
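Everything is now in place for the denoising loop below, whose core arithmetic is the classifier-free guidance update `eps_uncond + guidance_scale * (eps_text - eps_uncond)` applied to the doubled batch. With dummy tensors the chunk-and-combine step looks like this:

```python
import torch

guidance_scale = 7.5
# UNet output for the doubled batch: first half unconditional, second half text-conditioned
noise_pred = torch.randn(2, 4, 64, 64)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)   # torch.Size([1, 4, 64, 64])
```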
|
549 |
+
|
550 |
+
@torch.no_grad()
|
551 |
+
def __call__(
|
552 |
+
self,
|
553 |
+
prompt: Union[str, List[str]] = None,
|
554 |
+
image: PipelineImageInput = None,
|
555 |
+
mask_image: PipelineImageInput = None,
|
556 |
+
masked_image_latents: torch.FloatTensor = None,
|
557 |
+
height: Optional[int] = None,
|
558 |
+
width: Optional[int] = None,
|
559 |
+
strength: float = 1.0,
|
560 |
+
num_inference_steps: int = 50,
|
561 |
+
guidance_scale: float = 7.5,
|
562 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
563 |
+
num_images_per_prompt: Optional[int] = 1,
|
564 |
+
eta: float = 0.0,
|
565 |
+
generator: Optional[Union[torch.Generator,
|
566 |
+
List[torch.Generator]]] = None,
|
567 |
+
latents: Optional[torch.FloatTensor] = None,
|
568 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
569 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
570 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
571 |
+
param_scale_dict: Optional[dict] = {}
|
572 |
+
):
|
573 |
+
r"""
|
574 |
+
Function invoked when calling the pipeline for generation.
|
575 |
+
|
576 |
+
Args:
|
577 |
+
prompt (`str` or `List[str]`, *optional*):
|
578 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
579 |
+
instead.
|
580 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
581 |
+
The height in pixels of the generated image.
|
582 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
583 |
+
The width in pixels of the generated image.
|
584 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
585 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
586 |
+
expense of slower inference.
|
587 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
588 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
589 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
590 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
591 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
592 |
+
usually at the expense of lower image quality.
|
593 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
594 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
595 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
596 |
+
less than `1`).
|
597 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
598 |
+
The number of images to generate per prompt.
|
599 |
+
eta (`float`, *optional*, defaults to 0.0):
|
600 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
601 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
602 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
603 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
604 |
+
to make generation deterministic.
|
605 |
+
latents (`torch.FloatTensor`, *optional*):
|
606 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
607 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
608 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
609 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
610 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
611 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
612 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
613 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
614 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
615 |
+
argument.
|
616 |
+
|
617 |
+
"""
|
618 |
+
# 0. Default height and width to unet
|
619 |
+
height = height or self.unet_config_sample_size * self.vae_scale_factor
|
620 |
+
width = width or self.unet_config_sample_size * self.vae_scale_factor
|
621 |
+
# self.unet_config.sample_size = 64
|
622 |
+
# height = 512
|
623 |
+
# width = 512
|
624 |
+
|
625 |
+
# 1. Check inputs. Raise error if not correct
|
626 |
+
# self.check_inputs(
|
627 |
+
# prompt, height, width, negative_prompt, prompt_embeds, negative_prompt_embeds
|
628 |
+
# )
|
629 |
+
|
630 |
+
# 2. Define call parameters
|
631 |
+
if prompt is not None and isinstance(prompt, str):
|
632 |
+
batch_size = 1
|
633 |
+
elif prompt is not None and isinstance(prompt, list):
|
634 |
+
batch_size = len(prompt)
|
635 |
+
else:
|
636 |
+
batch_size = prompt_embeds.shape[0]
|
637 |
+
|
638 |
+
device = self.device
|
639 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
640 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
641 |
+
# corresponds to doing no classifier free guidance.
|
642 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
643 |
+
|
644 |
+
# 3. Encode input prompt
|
645 |
+
prompt_embeds = self._encode_prompt(
|
646 |
+
prompt,
|
647 |
+
device,
|
648 |
+
num_images_per_prompt,
|
649 |
+
do_classifier_free_guidance,
|
650 |
+
negative_prompt,
|
651 |
+
prompt_embeds=prompt_embeds,
|
652 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
653 |
+
)
|
654 |
+
|
655 |
+
# 3.5 Encode ipadapter_image
|
656 |
+
if ip_adapter_image is not None:
|
657 |
+
image_embeds, negative_image_embeds = self.encode_image(
|
658 |
+
ip_adapter_image, device, num_images_per_prompt)
|
659 |
+
if do_classifier_free_guidance:
|
660 |
+
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
661 |
+
image_embeds = self.encoder_hid_proj(image_embeds).to(self.dtype)
|
662 |
+
|
663 |
+
# 4. Prepare timesteps
|
664 |
+
# self.scheduler.set_timesteps(num_inference_steps, device=device)
|
665 |
+
# timesteps = self.scheduler.timesteps
|
666 |
+
|
667 |
+
# 4.5 Prepare mask and image
|
668 |
+
timesteps = None
|
669 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
670 |
+
self.scheduler, num_inference_steps, device, timesteps)
|
671 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
672 |
+
num_inference_steps=num_inference_steps, strength=strength, device=device
|
673 |
+
)
|
674 |
+
# check that number of inference steps is not < 1 - as this doesn't make sense
|
675 |
+
if num_inference_steps < 1:
|
676 |
+
raise ValueError(
|
677 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
678 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
679 |
+
)
|
680 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
681 |
+
latent_timestep = timesteps[:1].repeat(
|
682 |
+
batch_size * num_images_per_prompt)
|
683 |
+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
684 |
+
is_strength_max = strength == 1.0
|
685 |
+
|
686 |
+
# 5. Preprocess mask and image
|
687 |
+
|
688 |
+
init_image = self.image_processor.preprocess(
|
689 |
+
image, height=height, width=width)
|
690 |
+
init_image = init_image.to(dtype=torch.float32)
|
691 |
+
|
692 |
+
# 5. Prepare latent variables
|
693 |
+
return_image_latents = self.num_channels_unet == 4
|
694 |
+
latents_outputs = self.prepare_latents(
|
695 |
+
batch_size * num_images_per_prompt,
|
696 |
+
self.num_channels_latents,
|
697 |
+
height,
|
698 |
+
width,
|
699 |
+
prompt_embeds.dtype,
|
700 |
+
device,
|
701 |
+
generator,
|
702 |
+
latents,
|
703 |
+
image=init_image,
|
704 |
+
timestep=latent_timestep,
|
705 |
+
is_strength_max=is_strength_max,
|
706 |
+
return_noise=True,
|
707 |
+
return_image_latents=return_image_latents
|
708 |
+
)
|
709 |
+
|
710 |
+
if return_image_latents:
|
711 |
+
latents, noise, image_latents = latents_outputs
|
712 |
+
else:
|
713 |
+
latents, noise = latents_outputs
|
714 |
+
|
715 |
+
# 5.5 Prepare mask latent variables
|
716 |
+
mask_condition = self.mask_processor.preprocess(
|
717 |
+
mask_image, height=height, width=width)
|
718 |
+
if masked_image_latents is None:
|
719 |
+
masked_image = init_image * (mask_condition < 0.5)
|
720 |
+
else:
|
721 |
+
masked_image = masked_image_latents
|
722 |
+
|
723 |
+
mask, masked_image_latents = self.prepare_mask_latents(
|
724 |
+
mask_condition,
|
725 |
+
masked_image,
|
726 |
+
batch_size * num_images_per_prompt,
|
727 |
+
height,
|
728 |
+
width,
|
729 |
+
prompt_embeds.dtype,
|
730 |
+
device,
|
731 |
+
generator,
|
732 |
+
do_classifier_free_guidance,
|
733 |
+
)
|
734 |
+
|
735 |
+
# Check that sizes of mask, masked image and latents match
|
736 |
+
if self.num_channels_unet == 9:
|
737 |
+
# default case for runwayml/stable-diffusion-inpainting
|
738 |
+
num_channels_mask = mask.shape[1]
|
739 |
+
num_channels_masked_image = masked_image_latents.shape[1]
|
740 |
+
if self.num_channels_latents + num_channels_mask + num_channels_masked_image != self.num_channels_unet:
|
741 |
+
raise ValueError(
|
742 |
+
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
743 |
+
f" {self.num_channels_latents} but received `num_channels_latents`: {self.num_channels_latents} +"
|
744 |
+
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
745 |
+
f" = {self.num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
746 |
+
" `pipeline.unet` or your `mask_image` or `image` input."
|
747 |
+
)
|
748 |
+
elif self.num_channels_unet != 4:
|
749 |
+
raise ValueError(
|
750 |
+
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
|
751 |
+
)
|
752 |
+
|
753 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
754 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
755 |
+
|
756 |
+
# 7. Denoising loop
|
757 |
+
num_warmup_steps = len(timesteps) - \
|
758 |
+
num_inference_steps * self.scheduler.order
|
759 |
+
|
760 |
+
for i, t in enumerate(timesteps):
|
761 |
+
# expand the latents if we are doing classifier free guidance
|
762 |
+
latent_model_input = torch.cat(
|
763 |
+
[latents] * 2) if do_classifier_free_guidance else latents
|
764 |
+
latent_model_input = self.scheduler.scale_model_input(
|
765 |
+
latent_model_input, t)
|
766 |
+
|
767 |
+
if self.num_channels_unet == 9:
|
768 |
+
latent_model_input = torch.cat(
|
769 |
+
[latent_model_input, mask, masked_image_latents], dim=1)
|
770 |
+
|
771 |
+
latent_model_input = latent_model_input.permute(
|
772 |
+
0, 2, 3, 1).contiguous()
|
773 |
+
|
774 |
+
# latent_model_input = latent_model_input[:,:4,:,:].
|
775 |
+
|
776 |
+
# the trailing None arguments are ControlNet inputs; they are passed as None placeholders for now
|
777 |
+
# todo: forward ip image_embeds
|
778 |
+
# break
|
779 |
+
if ip_adapter_image is not None:
|
780 |
+
noise_pred = self.unet.forward(
|
781 |
+
latent_model_input, prompt_embeds, t, None, None, None, None, {"ip_hidden_states": image_embeds}, param_scale_dict)
|
782 |
+
else:
|
783 |
+
noise_pred = self.unet.forward(
|
784 |
+
latent_model_input, prompt_embeds, t)
|
785 |
+
|
786 |
+
noise_pred = noise_pred.permute(0, 3, 1, 2).contiguous()
|
787 |
+
# saver.save_v(f"latent_model_input_{i}", latent_model_input)
|
788 |
+
# saver.save_v(f"noise_pred_{i}", noise_pred)
|
789 |
+
# saver.save_v(f"prompt_embeds_{i}", prompt_embeds)
|
790 |
+
|
791 |
+
# perform guidance
|
792 |
+
if do_classifier_free_guidance:
|
793 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
794 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
795 |
+
(noise_pred_text - noise_pred_uncond)
|
796 |
+
|
797 |
+
# compute the previous noisy sample x_t -> x_t-1
|
798 |
+
latents = self.scheduler.step(
|
799 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
800 |
+
if self.num_channels_unet == 4:
|
801 |
+
init_latents_proper = image_latents
|
802 |
+
if do_classifier_free_guidance:
|
803 |
+
init_mask, _ = mask.chunk(2)
|
804 |
+
else:
|
805 |
+
init_mask = mask
|
806 |
+
|
807 |
+
if i < len(timesteps) - 1:
|
808 |
+
noise_timestep = timesteps[i + 1]
|
809 |
+
init_latents_proper = self.scheduler.add_noise(
|
810 |
+
init_latents_proper, noise, torch.tensor(
|
811 |
+
[noise_timestep])
|
812 |
+
)
|
813 |
+
|
814 |
+
latents = (1 - init_mask) * init_latents_proper + \
|
815 |
+
init_mask * latents
|
816 |
+
|
817 |
+
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
818 |
+
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
819 |
+
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
820 |
+
# # compute the previous noisy sample x_t -> x_t-1
|
821 |
+
# latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
822 |
+
# image = self.decode_latents(latents)
|
823 |
+
image = self.lyra_decode_latents(latents)
|
824 |
+
image = numpy_to_pil(image)
|
825 |
+
|
826 |
+
return image
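End to end, this inpainting entry point mirrors the other pipelines in the repo: build the pipeline, load the converted weights, then call it with an image, a mask and a prompt. The sketch below is written from the signatures above only; the import name, constructor arguments and file paths are assumptions, not a definitive API:

```python
import torch
from PIL import Image
# assumed export name; check lyrasd_model/__init__.py for the actual symbol
from lyrasd_model import LyraSdTxt2ImgInpaintPipeline

pipe = LyraSdTxt2ImgInpaintPipeline()   # device/dtype defaults assumed, as in the other Lyra pipelines
# ... load the converted lyrasd model weights here, following the repo's demo scripts ...

init_image = Image.open("room.png").convert("RGB")
mask = Image.open("room_mask.png").convert("L")      # white = repaint, black = keep

images = pipe(
    prompt="a cozy armchair, high quality",
    image=init_image,
    mask_image=mask,
    height=512, width=512,
    strength=0.85,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=torch.Generator().manual_seed(123),
)
images[0].save("inpaint_result.png")
```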
|
lyrasd_model/lyrasd_txt2img_pipeline.py
CHANGED
@@ -2,7 +2,7 @@ import inspect
|
|
2 |
import os
|
3 |
import time
|
4 |
from typing import Any, Callable, Dict, List, Optional, Union
|
5 |
-
|
6 |
import torch
|
7 |
from diffusers.loaders import TextualInversionLoaderMixin
|
8 |
from diffusers.models import AutoencoderKL
|
@@ -10,17 +10,43 @@ from diffusers.schedulers import (DPMSolverMultistepScheduler,
|
|
10 |
EulerAncestralDiscreteScheduler,
|
11 |
EulerDiscreteScheduler,
|
12 |
KarrasDiffusionSchedulers)
|
13 |
-
from diffusers.utils import logging, randn_tensor
|
14 |
from PIL import Image
|
15 |
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
16 |
import gc
|
17 |
import numpy as np
|
18 |
|
19 |
-
from .
|
|
|
|
20 |
|
21 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def numpy_to_pil(images):
|
25 |
"""
|
26 |
Convert a numpy image or a batch of images to a PIL image.
|
@@ -30,68 +56,18 @@ def numpy_to_pil(images):
|
|
30 |
images = (images * 255).round().astype("uint8")
|
31 |
if images.shape[-1] == 1:
|
32 |
# special case for grayscale (single channel) images
|
33 |
-
pil_images = [Image.fromarray(image.squeeze(), mode="L")
|
|
|
34 |
else:
|
35 |
pil_images = [Image.fromarray(image) for image in images]
|
36 |
|
37 |
return pil_images
|
38 |
|
39 |
|
40 |
-
class LyraSdTxt2ImgPipeline(
|
41 |
-
def __init__(self,
|
42 |
-
|
43 |
-
self.dtype = dtype
|
44 |
-
|
45 |
-
torch.classes.load_library(lib_so_path)
|
46 |
-
|
47 |
-
self.vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(dtype).to(device)
|
48 |
-
self.tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
|
49 |
-
self.text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(dtype).to(device)
|
50 |
-
unet_path = os.path.join(model_path, "unet_bins/")
|
51 |
-
|
52 |
-
self.unet_in_channels = 4
|
53 |
-
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
54 |
-
self.vae.enable_tiling()
|
55 |
-
self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
|
56 |
-
3, # max num of controlnets
|
57 |
-
"fp16" # inference dtype (can only use fp16 for now)
|
58 |
-
)
|
59 |
-
|
60 |
-
unet_path = os.path.join(model_path, "unet_bins/")
|
61 |
|
62 |
-
self.reload_unet_model(unet_path, model_dtype)
|
63 |
-
|
64 |
-
self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
|
65 |
-
|
66 |
-
self.loaded_lora = {}
|
67 |
-
|
68 |
-
def reload_unet_model(self, unet_path, unet_file_format='fp32'):
|
69 |
-
if len(unet_path) > 0 and unet_path[-1] != "/":
|
70 |
-
unet_path = unet_path + "/"
|
71 |
-
return self.unet.reload_unet_model(unet_path, unet_file_format)
|
72 |
-
|
73 |
-
def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
|
74 |
-
if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
|
75 |
-
lora_model_path = lora_model_path + "/"
|
76 |
-
lora = add_text_lora_layer(self.text_encoder, lora_model_path, lora_strength, lora_file_format)
|
77 |
-
self.loaded_lora[lora_name] = lora
|
78 |
-
self.unet.load_lora(lora_model_path, lora_name, lora_strength, lora_file_format)
|
79 |
-
|
80 |
-
def unload_lora(self, lora_name, clean_cache=False):
|
81 |
-
for layer_data in self.loaded_lora[lora_name]:
|
82 |
-
layer = layer_data['layer']
|
83 |
-
added_weight = layer_data['added_weight']
|
84 |
-
layer.weight.data -= added_weight
|
85 |
-
self.unet.unload_lora(lora_name, clean_cache)
|
86 |
-
del self.loaded_lora[lora_name]
|
87 |
-
gc.collect()
|
88 |
-
torch.cuda.empty_cache()
|
89 |
-
|
90 |
-
def clean_lora_cache(self):
|
91 |
-
self.unet.clean_lora_cache()
|
92 |
-
|
93 |
-
def get_loaded_lora(self):
|
94 |
-
return self.unet.get_loaded_lora()
|
95 |
|
96 |
def _encode_prompt(
|
97 |
self,
|
@@ -147,7 +123,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
147 |
return_tensors="pt",
|
148 |
)
|
149 |
text_input_ids = text_inputs.input_ids
|
150 |
-
untruncated_ids = self.tokenizer(
|
|
|
151 |
|
152 |
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
153 |
text_input_ids, untruncated_ids
|
@@ -171,12 +148,14 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
171 |
)
|
172 |
prompt_embeds = prompt_embeds[0]
|
173 |
|
174 |
-
prompt_embeds = prompt_embeds.to(
|
|
|
175 |
|
176 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
177 |
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
178 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
179 |
-
prompt_embeds = prompt_embeds.view(
|
|
|
180 |
|
181 |
# get unconditional embeddings for classifier free guidance
|
182 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
@@ -201,7 +180,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
201 |
|
202 |
# textual inversion: process multi-vector tokens if necessary
|
203 |
if isinstance(self, TextualInversionLoaderMixin):
|
204 |
-
uncond_tokens = self.maybe_convert_prompt(
|
|
|
205 |
|
206 |
max_length = prompt_embeds.shape[1]
|
207 |
uncond_input = self.tokenizer(
|
@@ -227,10 +207,13 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
227 |
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
228 |
seq_len = negative_prompt_embeds.shape[1]
|
229 |
|
230 |
-
negative_prompt_embeds = negative_prompt_embeds.to(
|
|
|
231 |
|
232 |
-
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
233 |
-
|
|
|
|
|
234 |
|
235 |
# For classifier free guidance, we need to do two forward passes.
|
236 |
# Here we concatenate the unconditional and text embeddings into a single batch
|
@@ -239,14 +222,83 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
239 |
|
240 |
return prompt_embeds
|
241 |
|
|
|
242 |
def decode_latents(self, latents):
|
243 |
-
latents = 1 / self.vae.
|
244 |
image = self.vae.decode(latents).sample
|
245 |
image = (image / 2 + 0.5).clamp(0, 1)
|
246 |
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
247 |
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
248 |
return image
|
249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
def check_inputs(
|
251 |
self,
|
252 |
prompt,
|
@@ -257,7 +309,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
257 |
negative_prompt_embeds=None,
|
258 |
):
|
259 |
if height % 64 != 0 or width % 64 != 0:  # the initial release only supports heights and widths that are multiples of 64
|
260 |
-
raise ValueError(
|
|
|
261 |
|
262 |
if prompt is not None and prompt_embeds is not None:
|
263 |
raise ValueError(
|
@@ -269,7 +322,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
269 |
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
270 |
)
|
271 |
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
272 |
-
raise ValueError(
|
|
|
273 |
|
274 |
if negative_prompt is not None and negative_prompt_embeds is not None:
|
275 |
raise ValueError(
|
@@ -286,7 +340,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
286 |
)
|
287 |
|
288 |
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
289 |
-
shape = (batch_size, num_channels_latents, height //
|
|
|
290 |
if isinstance(generator, list) and len(generator) != batch_size:
|
291 |
raise ValueError(
|
292 |
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
@@ -294,7 +349,8 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
294 |
)
|
295 |
|
296 |
if latents is None:
|
297 |
-
latents = randn_tensor(
|
|
|
298 |
else:
|
299 |
latents = latents.to(device)
|
300 |
|
@@ -308,13 +364,15 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
308 |
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
309 |
# and should be between [0, 1]
|
310 |
|
311 |
-
accepts_eta = "eta" in set(inspect.signature(
|
|
|
312 |
extra_step_kwargs = {}
|
313 |
if accepts_eta:
|
314 |
extra_step_kwargs["eta"] = eta
|
315 |
|
316 |
# check if the scheduler accepts generator
|
317 |
-
accepts_generator = "generator" in set(
|
|
|
318 |
if accepts_generator:
|
319 |
extra_step_kwargs["generator"] = generator
|
320 |
return extra_step_kwargs
|
@@ -330,10 +388,13 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
330 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
331 |
num_images_per_prompt: Optional[int] = 1,
|
332 |
eta: float = 0.0,
|
333 |
-
generator: Optional[Union[torch.Generator,
|
|
|
334 |
latents: Optional[torch.FloatTensor] = None,
|
335 |
prompt_embeds: Optional[torch.FloatTensor] = None,
|
336 |
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
|
337 |
):
|
338 |
r"""
|
339 |
Function invoked when calling the pipeline for generation.
|
@@ -410,6 +471,14 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
410 |
negative_prompt_embeds=negative_prompt_embeds,
|
411 |
)
|
412 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
413 |
# 4. Prepare timesteps
|
414 |
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
415 |
timesteps = self.scheduler.timesteps
|
@@ -431,28 +500,46 @@ class LyraSdTxt2ImgPipeline(TextualInversionLoaderMixin):
|
|
431 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
432 |
|
433 |
# 7. Denoising loop
|
434 |
-
num_warmup_steps = len(timesteps) -
|
|
|
435 |
|
436 |
for i, t in enumerate(timesteps):
|
437 |
# expand the latents if we are doing classifier free guidance
|
438 |
-
latent_model_input = torch.cat(
|
439 |
-
|
440 |
-
latent_model_input =
|
441 |
-
|
442 |
-
|
443 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
444 |
|
445 |
noise_pred = noise_pred.permute(0, 3, 1, 2)
|
446 |
-
# perform guidance
|
447 |
|
|
|
|
|
|
|
448 |
if do_classifier_free_guidance:
|
449 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
450 |
-
noise_pred = noise_pred_uncond + guidance_scale *
|
|
|
451 |
|
|
|
|
|
|
|
452 |
# compute the previous noisy sample x_t -> x_t-1
|
453 |
-
latents = self.scheduler.step(
|
454 |
-
|
455 |
-
image = self.decode_latents(latents)
|
|
|
456 |
image = numpy_to_pil(image)
|
457 |
|
458 |
return image
|
|
|
2 |
import os
|
3 |
import time
|
4 |
from typing import Any, Callable, Dict, List, Optional, Union
|
5 |
+
import GPUtil
|
6 |
import torch
|
7 |
from diffusers.loaders import TextualInversionLoaderMixin
|
8 |
from diffusers.models import AutoencoderKL
|
|
|
10 |
EulerAncestralDiscreteScheduler,
|
11 |
EulerDiscreteScheduler,
|
12 |
KarrasDiffusionSchedulers)
|
13 |
+
from diffusers.utils.torch_utils import logging, randn_tensor
|
14 |
from PIL import Image
|
15 |
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
16 |
import gc
|
17 |
import numpy as np
|
18 |
|
19 |
+
from .lyrasd_vae_model import LyraSdVaeModel
|
20 |
+
|
21 |
+
from diffusers.image_processor import PipelineImageInput
|
22 |
+
from diffusers.models.embeddings import ImageProjection
|
23 |
+
from transformers import (
|
24 |
+
CLIPImageProcessor,
|
25 |
+
CLIPVisionModelWithProjection,
|
26 |
+
)
|
27 |
+
from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
|
28 |
+
from safetensors.torch import load_file
|
29 |
+
from .lyrasd_pipeline_base import LyraSDXLPipelineBase
|
30 |
|
31 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32 |
|
33 |
|
34 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
35 |
+
"""
|
36 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
37 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
38 |
+
"""
|
39 |
+
std_text = noise_pred_text.std(
|
40 |
+
dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
41 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
42 |
+
# rescale the results from guidance (fixes overexposure)
|
43 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
44 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
45 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + \
|
46 |
+
(1 - guidance_rescale) * noise_cfg
|
47 |
+
return noise_cfg
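`rescale_noise_cfg` only re-scales the combined prediction so its per-sample standard deviation moves back toward that of the text branch. A quick check with random tensors, assuming the helper defined just above is in scope:

```python
import torch

noise_pred_text = torch.randn(2, 4, 64, 64)
noise_cfg = 7.5 * torch.randn(2, 4, 64, 64)        # stand-in for an over-amplified CFG output

rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
# with guidance_rescale=1.0 the per-sample std matches the text branch exactly
print(noise_pred_text.std(dim=(1, 2, 3)))
print(rescaled.std(dim=(1, 2, 3)))
```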
|
48 |
+
|
49 |
+
|
50 |
def numpy_to_pil(images):
|
51 |
"""
|
52 |
Convert a numpy image or a batch of images to a PIL image.
|
|
|
56 |
images = (images * 255).round().astype("uint8")
|
57 |
if images.shape[-1] == 1:
|
58 |
# special case for grayscale (single channel) images
|
59 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L")
|
60 |
+
for image in images]
|
61 |
else:
|
62 |
pil_images = [Image.fromarray(image) for image in images]
|
63 |
|
64 |
return pil_images
|
65 |
|
66 |
|
67 |
+
class LyraSdTxt2ImgPipeline(LyraSDXLPipelineBase):
|
68 |
+
def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
|
69 |
+
super().__init__(device, dtype, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)
|
|
|
71 |
|
72 |
def _encode_prompt(
|
73 |
self,
|
|
|
123 |
return_tensors="pt",
|
124 |
)
|
125 |
text_input_ids = text_inputs.input_ids
|
126 |
+
untruncated_ids = self.tokenizer(
|
127 |
+
prompt, padding="longest", return_tensors="pt").input_ids
|
128 |
|
129 |
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
130 |
text_input_ids, untruncated_ids
|
|
|
148 |
)
|
149 |
prompt_embeds = prompt_embeds[0]
|
150 |
|
151 |
+
prompt_embeds = prompt_embeds.to(
|
152 |
+
dtype=self.text_encoder.dtype, device=device)
|
153 |
|
154 |
bs_embed, seq_len, _ = prompt_embeds.shape
|
155 |
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
156 |
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
157 |
+
prompt_embeds = prompt_embeds.view(
|
158 |
+
bs_embed * num_images_per_prompt, seq_len, -1)
|
159 |
|
160 |
# get unconditional embeddings for classifier free guidance
|
161 |
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
|
|
180 |
|
181 |
# textual inversion: process multi-vector tokens if necessary
|
182 |
if isinstance(self, TextualInversionLoaderMixin):
|
183 |
+
uncond_tokens = self.maybe_convert_prompt(
|
184 |
+
uncond_tokens, self.tokenizer)
|
185 |
|
186 |
max_length = prompt_embeds.shape[1]
|
187 |
uncond_input = self.tokenizer(
|
|
|
207 |
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
208 |
seq_len = negative_prompt_embeds.shape[1]
|
209 |
|
210 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
211 |
+
dtype=self.text_encoder.dtype, device=device)
|
212 |
|
213 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
214 |
+
1, num_images_per_prompt, 1)
|
215 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
216 |
+
batch_size * num_images_per_prompt, seq_len, -1)
|
217 |
|
218 |
# For classifier free guidance, we need to do two forward passes.
|
219 |
# Here we concatenate the unconditional and text embeddings into a single batch
|
|
|
222 |
|
223 |
return prompt_embeds
|
224 |
|
225 |
+
def load_ip_adapter(self,
|
226 |
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
227 |
+
subfolder: str,
|
228 |
+
weight_name: str,
|
229 |
+
**kwargs
|
230 |
+
):
|
231 |
+
# if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
232 |
+
self.feature_extractor = CLIPImageProcessor()
|
233 |
+
|
234 |
+
# if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
235 |
+
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
236 |
+
pretrained_model_name_or_path_or_dict,
|
237 |
+
subfolder=os.path.join(subfolder, "image_encoder"),
|
238 |
+
).to(self.device, dtype=self.dtype)
|
239 |
+
# else:
|
240 |
+
# print("kio: already has image_encoder", hasattr(self, "image_encoder"), getattr(self, "feature_extractor", None) is None)
|
241 |
+
|
242 |
+
# kiotodo: init ImageProjection
|
243 |
+
model_path = os.path.join(
|
244 |
+
pretrained_model_name_or_path_or_dict, subfolder, weight_name)
|
245 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
246 |
+
|
247 |
+
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
|
248 |
+
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
|
249 |
+
self.encoder_hid_proj = ImageProjection(
|
250 |
+
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
|
251 |
+
)
|
252 |
+
|
253 |
+
image_proj_state_dict = {}
|
254 |
+
image_proj_state_dict.update(
|
255 |
+
{
|
256 |
+
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
|
257 |
+
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
|
258 |
+
"norm.weight": state_dict["image_proj"]["norm.weight"],
|
259 |
+
"norm.bias": state_dict["image_proj"]["norm.bias"],
|
260 |
+
}
|
261 |
+
)
|
262 |
+
|
263 |
+
self.encoder_hid_proj.load_state_dict(image_proj_state_dict)
|
264 |
+
self.encoder_hid_proj.to(dtype=self.dtype, device=self.device)
|
265 |
+
|
266 |
+
dir_ipadapter = os.path.join(
|
267 |
+
pretrained_model_name_or_path_or_dict, subfolder, '.'.join(weight_name.split(".")[:-1]))
|
268 |
+
self.unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
|
269 |
+
|
270 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
271 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
272 |
+
if not isinstance(image, torch.Tensor):
|
273 |
+
image = self.feature_extractor(
|
274 |
+
image, return_tensors="pt").pixel_values
|
275 |
+
|
276 |
+
image = image.to(device=device, dtype=dtype)
|
277 |
+
image_embeds = self.image_encoder(image).image_embeds
|
278 |
+
image_embeds = image_embeds.repeat_interleave(
|
279 |
+
num_images_per_prompt, dim=0)
|
280 |
+
|
281 |
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
282 |
+
return image_embeds, uncond_image_embeds
|
283 |
+
|
284 |
def decode_latents(self, latents):
|
285 |
+
latents = 1 / self.vae.scaling_factor * latents
|
286 |
image = self.vae.decode(latents).sample
|
287 |
image = (image / 2 + 0.5).clamp(0, 1)
|
288 |
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
289 |
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
290 |
return image
|
291 |
|
292 |
+
def lyra_decode_latents(self, latents):
|
293 |
+
latents = 1 / self.vae.scaling_factor * latents
|
294 |
+
image = self.vae.decode(latents)
|
295 |
+
image = image.permute(0, 2, 3, 1)
|
296 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
297 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
298 |
+
image = image.cpu().float().numpy()
|
299 |
+
|
300 |
+
return image
|
301 |
+
|
302 |
def check_inputs(
|
303 |
self,
|
304 |
prompt,
|
|
|
309 |
negative_prompt_embeds=None,
|
310 |
):
|
311 |
if height % 64 != 0 or width % 64 != 0:  # the initial release only supports heights and widths that are multiples of 64
|
312 |
+
raise ValueError(
|
313 |
+
f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
|
314 |
|
315 |
if prompt is not None and prompt_embeds is not None:
|
316 |
raise ValueError(
|
|
|
322 |
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
323 |
)
|
324 |
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
325 |
+
raise ValueError(
|
326 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
327 |
|
328 |
if negative_prompt is not None and negative_prompt_embeds is not None:
|
329 |
raise ValueError(
|
|
|
340 |
)
|
341 |
|
342 |
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
343 |
+
shape = (batch_size, num_channels_latents, height //
|
344 |
+
self.vae.scale_factor, width // self.vae.scale_factor)
|
345 |
if isinstance(generator, list) and len(generator) != batch_size:
|
346 |
raise ValueError(
|
347 |
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
|
|
349 |
)
|
350 |
|
351 |
if latents is None:
|
352 |
+
latents = randn_tensor(
|
353 |
+
shape, generator=generator, device=device, dtype=dtype)
|
354 |
else:
|
355 |
latents = latents.to(device)
|
356 |
|
|
|
364 |
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
365 |
# and should be between [0, 1]
|
366 |
|
367 |
+
accepts_eta = "eta" in set(inspect.signature(
|
368 |
+
self.scheduler.step).parameters.keys())
|
369 |
extra_step_kwargs = {}
|
370 |
if accepts_eta:
|
371 |
extra_step_kwargs["eta"] = eta
|
372 |
|
373 |
# check if the scheduler accepts generator
|
374 |
+
accepts_generator = "generator" in set(
|
375 |
+
inspect.signature(self.scheduler.step).parameters.keys())
|
376 |
if accepts_generator:
|
377 |
extra_step_kwargs["generator"] = generator
|
378 |
return extra_step_kwargs
|
|
|
388 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
389 |
num_images_per_prompt: Optional[int] = 1,
|
390 |
eta: float = 0.0,
|
391 |
+
generator: Optional[Union[torch.Generator,
|
392 |
+
List[torch.Generator]]] = None,
|
393 |
latents: Optional[torch.FloatTensor] = None,
|
394 |
prompt_embeds: Optional[torch.FloatTensor] = None,
|
395 |
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
396 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
397 |
+
param_scale_dict: Optional[dict] = {}
|
398 |
):
|
399 |
r"""
|
400 |
Function invoked when calling the pipeline for generation.
|
|
|
471 |
negative_prompt_embeds=negative_prompt_embeds,
|
472 |
)
|
473 |
|
474 |
+
# 3.5 Encode ipadapter_image
|
475 |
+
if ip_adapter_image is not None:
|
476 |
+
image_embeds, negative_image_embeds = self.encode_image(
|
477 |
+
ip_adapter_image, device, num_images_per_prompt)
|
478 |
+
if do_classifier_free_guidance:
|
479 |
+
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
480 |
+
image_embeds = self.encoder_hid_proj(image_embeds).to(self.dtype)
|
481 |
+
|
482 |
# 4. Prepare timesteps
|
483 |
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
484 |
timesteps = self.scheduler.timesteps
|
|
|
500 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
501 |
|
502 |
# 7. Denoising loop
|
503 |
+
num_warmup_steps = len(timesteps) - \
|
504 |
+
num_inference_steps * self.scheduler.order
|
505 |
|
506 |
for i, t in enumerate(timesteps):
|
507 |
# expand the latents if we are doing classifier free guidance
|
508 |
+
latent_model_input = torch.cat(
|
509 |
+
[latents] * 2) if do_classifier_free_guidance else latents
|
510 |
+
latent_model_input = self.scheduler.scale_model_input(
|
511 |
+
latent_model_input, t)
|
512 |
+
latent_model_input = latent_model_input.permute(
|
513 |
+
0, 2, 3, 1).contiguous()
|
514 |
+
|
515 |
+
# the trailing None arguments are ControlNet inputs; they are passed as None placeholders for now
|
516 |
+
# todo: forward ip image_embeds
|
517 |
+
# break
|
518 |
+
if ip_adapter_image is not None:
|
519 |
+
noise_pred = self.unet.forward(
|
520 |
+
latent_model_input, prompt_embeds, t, None, None, None, None, {"ip_hidden_states": image_embeds}, param_scale_dict)
|
521 |
+
else:
|
522 |
+
noise_pred = self.unet.forward(
|
523 |
+
latent_model_input, prompt_embeds, t)
|
524 |
|
525 |
noise_pred = noise_pred.permute(0, 3, 1, 2)
|
|
|
526 |
|
527 |
+
np.save(f"/workspace/noise_pred_{i}.npy", noise_pred.detach().cpu().numpy())
|
528 |
+
|
529 |
+
# perform guidance
|
530 |
if do_classifier_free_guidance:
|
531 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
532 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
533 |
+
(noise_pred_text - noise_pred_uncond)
|
534 |
|
535 |
+
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
536 |
+
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
537 |
+
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
538 |
# compute the previous noisy sample x_t -> x_t-1
|
539 |
+
latents = self.scheduler.step(
|
540 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
541 |
+
# image = self.decode_latents(latents)
|
542 |
+
image = self.lyra_decode_latents(latents)
|
543 |
image = numpy_to_pil(image)
|
544 |
|
545 |
return image
|
lyrasd_model/lyrasd_vae_model.py
ADDED
@@ -0,0 +1,363 @@
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Dict, Optional, Tuple, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
20 |
+
import numpy as np
|
21 |
+
|
22 |
+
from safetensors.torch import load_file
|
23 |
+
|
24 |
+
import os
|
25 |
+
|
26 |
+
class LyraSdVaeModel():
|
27 |
+
r"""
|
28 |
+
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
29 |
+
|
30 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
31 |
+
for all models (such as downloading or saving).
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
35 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
36 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
37 |
+
Tuple of downsample block types.
|
38 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
39 |
+
Tuple of upsample block types.
|
40 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
41 |
+
Tuple of block output channels.
|
42 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
43 |
+
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
44 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
45 |
+
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
46 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
47 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
48 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
49 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
50 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
51 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
52 |
+
force_upcast (`bool`, *optional*, default to `True`):
|
53 |
+
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
54 |
+
can be fine-tuned / trained to a lower range without loosing too much precision in which case
|
55 |
+
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
56 |
+
"""
|
57 |
+
|
58 |
+
_supports_gradient_checkpointing = True
|
59 |
+
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
dtype: str = "fp16",
|
63 |
+
scaling_factor: float = 0.18215,
|
64 |
+
scale_factor: int = 8,
|
65 |
+
is_upcast: bool = False
|
66 |
+
):
|
67 |
+
super().__init__()
|
68 |
+
self.is_upcast = is_upcast
|
69 |
+
self.scaling_factor = scaling_factor
|
70 |
+
self.scale_factor = scale_factor
|
71 |
+
self.model = torch.classes.lyrasd.VaeModelOp(
|
72 |
+
dtype,
|
73 |
+
is_upcast
|
74 |
+
)
|
75 |
+
|
76 |
+
self.vae_cache = {}
|
77 |
+
|
78 |
+
self.use_slicing = False
|
79 |
+
self.use_tiling = False
|
80 |
+
|
81 |
+
self.tile_latent_min_size = 512
|
82 |
+
self.tile_sample_min_size = 64
|
83 |
+
self.tile_overlap_factor = 0.25
|
84 |
+
|
85 |
+
def reload_vae_model(self, vae_path, vae_file_format='fp32'):
|
86 |
+
if len(vae_path) > 0 and vae_path[-1] != "/":
|
87 |
+
vae_path = vae_path + "/"
|
88 |
+
return self.model.reload_vae_model(vae_path, vae_file_format)
|
89 |
+
|
90 |
+
def reload_vae_model_v2(self, model_path):
|
91 |
+
checkpoint_file = os.path.join(model_path, "vae/diffusion_pytorch_model.bin")
|
92 |
+
if not os.path.exists(checkpoint_file):
|
93 |
+
checkpoint_file = os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors")
|
94 |
+
if checkpoint_file in self.vae_cache:
|
95 |
+
state_dict = self.vae_cache[checkpoint_file]
|
96 |
+
else:
|
97 |
+
if "safetensors" in checkpoint_file:
|
98 |
+
state_dict = load_file(checkpoint_file)
|
99 |
+
else:
|
100 |
+
state_dict = torch.load(checkpoint_file, map_location="cpu")
|
101 |
+
|
102 |
+
# replace deprecated weights
|
103 |
+
for path in ["encoder.mid_block.attentions.0", "decoder.mid_block.attentions.0"]:
|
104 |
+
# group_norm path stays the same
|
105 |
+
|
106 |
+
# query -> to_q
|
107 |
+
if f"{path}.query.weight" in state_dict:
|
108 |
+
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
109 |
+
if f"{path}.query.bias" in state_dict:
|
110 |
+
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
111 |
+
|
112 |
+
# key -> to_k
|
113 |
+
if f"{path}.key.weight" in state_dict:
|
114 |
+
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
115 |
+
if f"{path}.key.bias" in state_dict:
|
116 |
+
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
117 |
+
|
118 |
+
# value -> to_v
|
119 |
+
if f"{path}.value.weight" in state_dict:
|
120 |
+
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
121 |
+
if f"{path}.value.bias" in state_dict:
|
122 |
+
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
123 |
+
|
124 |
+
# proj_attn -> to_out.0
|
125 |
+
if f"{path}.proj_attn.weight" in state_dict:
|
126 |
+
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
127 |
+
if f"{path}.proj_attn.bias" in state_dict:
|
128 |
+
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
129 |
+
|
130 |
+
for key in state_dict:
|
131 |
+
# print(key)
|
132 |
+
if len(state_dict[key].shape) == 4:
|
133 |
+
state_dict[key] = state_dict[key].permute(0,2,3,1).contiguous()
|
134 |
+
else:
|
135 |
+
state_dict[key] = state_dict[key]
|
136 |
+
if self.is_upcast and (key.startswith("decoder.up_blocks.2") or key.startswith("decoder.up_blocks.3") or key.startswith("decoder.conv_norm_out")):
|
137 |
+
# print(key)
|
138 |
+
state_dict[key] = state_dict[key].to(torch.float32)
|
139 |
+
else:
|
140 |
+
state_dict[key] = state_dict[key].to(torch.float16)
|
141 |
+
|
142 |
+
self.vae_cache[checkpoint_file] = state_dict
|
143 |
+
|
144 |
+
return self.model.reload_vae_model_from_cache(state_dict, "cpu")
|
145 |
+
|
146 |
+
def enable_tiling(self, use_tiling: bool = True):
|
147 |
+
r"""
|
148 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
149 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
150 |
+
processing larger images.
|
151 |
+
"""
|
152 |
+
self.use_tiling = use_tiling
|
153 |
+
|
154 |
+
def disable_tiling(self):
|
155 |
+
r"""
|
156 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
157 |
+
decoding in one step.
|
158 |
+
"""
|
159 |
+
self.enable_tiling(False)
|
160 |
+
|
161 |
+
def enable_slicing(self):
|
162 |
+
r"""
|
163 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
164 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
165 |
+
"""
|
166 |
+
self.use_slicing = True
|
167 |
+
|
168 |
+
def disable_slicing(self):
|
169 |
+
r"""
|
170 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
171 |
+
decoding in one step.
|
172 |
+
"""
|
173 |
+
self.use_slicing = False
|
174 |
+
|
175 |
+
def lyra_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
176 |
+
x = x.permute(0, 2, 3, 1).contiguous()
|
177 |
+
x = self.model.vae_decode(x)
|
178 |
+
return x.permute(0, 3, 1, 2)
|
179 |
+
|
180 |
+
def lyra_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
181 |
+
x = x.permute(0, 2, 3, 1).contiguous()
|
182 |
+
x = self.model.vae_encode(x)
|
183 |
+
return x.permute(0, 3, 1, 2)
|
184 |
+
|
185 |
+
def encode(
|
186 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
187 |
+
) -> DiagonalGaussianDistribution:
|
188 |
+
"""
|
189 |
+
Encode a batch of images into latents.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
x (`torch.FloatTensor`): Input batch of images.
|
193 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
194 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
The latent representations of the encoded images. If `return_dict` is True, a
|
198 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
199 |
+
"""
|
200 |
+
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
201 |
+
return self.tiled_encode(x, return_dict=return_dict)
|
202 |
+
|
203 |
+
if self.use_slicing and x.shape[0] > 1:
|
204 |
+
encoded_slices = [self.lyra_encode(
|
205 |
+
x_slice) for x_slice in x.split(1)]
|
206 |
+
h = torch.cat(encoded_slices)
|
207 |
+
posterior = DiagonalGaussianDistribution(h)
|
208 |
+
else:
|
209 |
+
moments = self.lyra_encode(x)
|
210 |
+
posterior = DiagonalGaussianDistribution(moments)
|
211 |
+
|
212 |
+
return posterior
|
213 |
+
|
214 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
|
215 |
+
if self.use_tiling and (z.shape[2] > self.tile_latent_min_size or z.shape[3] > self.tile_latent_min_size):
|
216 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
217 |
+
|
218 |
+
dec = self.lyra_decode(z)
|
219 |
+
|
220 |
+
return dec
|
221 |
+
|
222 |
+
def decode(
|
223 |
+
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
224 |
+
) -> torch.FloatTensor:
|
225 |
+
"""
|
226 |
+
Decode a batch of images.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
230 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
231 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
235 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
236 |
+
returned.
|
237 |
+
|
238 |
+
"""
|
239 |
+
if self.use_slicing and z.shape[0] > 1:
|
240 |
+
decoded_slices = [self._decode(
|
241 |
+
z_slice) for z_slice in z.split(1)]
|
242 |
+
decoded = torch.cat(decoded_slices)
|
243 |
+
else:
|
244 |
+
decoded = self._decode(z)
|
245 |
+
|
246 |
+
return decoded
|
247 |
+
|
248 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
249 |
+
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
250 |
+
for y in range(blend_extent):
|
251 |
+
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * \
|
252 |
+
(1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
253 |
+
return b
|
254 |
+
|
255 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
256 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
257 |
+
for x in range(blend_extent):
|
258 |
+
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * \
|
259 |
+
(1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
260 |
+
return b
|
261 |
+
|
262 |
+
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> DiagonalGaussianDistribution:
|
263 |
+
r"""Encode a batch of images using a tiled encoder.
|
264 |
+
|
265 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
266 |
+
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
267 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
268 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
269 |
+
output, but they should be much less noticeable.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
x (`torch.FloatTensor`): Input batch of images.
|
273 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
274 |
+
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
278 |
+
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
279 |
+
`tuple` is returned.
|
280 |
+
"""
|
281 |
+
overlap_size = int(self.tile_sample_min_size *
|
282 |
+
(1 - self.tile_overlap_factor))
|
283 |
+
blend_extent = int(self.tile_latent_min_size *
|
284 |
+
self.tile_overlap_factor)
|
285 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
286 |
+
|
287 |
+
# Split the image into 512x512 tiles and encode them separately.
|
288 |
+
rows = []
|
289 |
+
for i in range(0, x.shape[2], overlap_size):
|
290 |
+
row = []
|
291 |
+
for j in range(0, x.shape[3], overlap_size):
|
292 |
+
tile = x[:, :, i: i + self.tile_sample_min_size,
|
293 |
+
j: j + self.tile_sample_min_size]
|
294 |
+
tile = self.lyra_encode(tile)
|
295 |
+
row.append(tile)
|
296 |
+
rows.append(row)
|
297 |
+
result_rows = []
|
298 |
+
for i, row in enumerate(rows):
|
299 |
+
result_row = []
|
300 |
+
for j, tile in enumerate(row):
|
301 |
+
# blend the above tile and the left tile
|
302 |
+
# to the current tile and add the current tile to the result row
|
303 |
+
if i > 0:
|
304 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
305 |
+
if j > 0:
|
306 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
307 |
+
result_row.append(tile[:, :, :row_limit, :row_limit])
|
308 |
+
result_rows.append(torch.cat(result_row, dim=3))
|
309 |
+
|
310 |
+
moments = torch.cat(result_rows, dim=2)
|
311 |
+
posterior = DiagonalGaussianDistribution(moments)
|
312 |
+
|
313 |
+
return posterior
|
314 |
+
|
315 |
+
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
|
316 |
+
r"""
|
317 |
+
Decode a batch of images using a tiled decoder.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
321 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
322 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
323 |
+
|
324 |
+
Returns:
|
325 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
326 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
327 |
+
returned.
|
328 |
+
"""
|
329 |
+
overlap_size = int(self.tile_latent_min_size *
|
330 |
+
(1 - self.tile_overlap_factor))
|
331 |
+
blend_extent = int(self.tile_sample_min_size *
|
332 |
+
self.tile_overlap_factor)
|
333 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
334 |
+
|
335 |
+
# Split z into overlapping 64x64 tiles and decode them separately.
|
336 |
+
# The tiles have an overlap to avoid seams between tiles.
|
337 |
+
rows = []
|
338 |
+
for i in range(0, z.shape[2], overlap_size):
|
339 |
+
row = []
|
340 |
+
for j in range(0, z.shape[3], overlap_size):
|
341 |
+
tile = z[:, :, i: i + self.tile_latent_min_size,
|
342 |
+
j: j + self.tile_latent_min_size]
|
343 |
+
decoded = self.lyra_decode(tile)
|
344 |
+
row.append(decoded)
|
345 |
+
rows.append(row)
|
346 |
+
result_rows = []
|
347 |
+
for i, row in enumerate(rows):
|
348 |
+
result_row = []
|
349 |
+
for j, tile in enumerate(row):
|
350 |
+
# blend the above tile and the left tile
|
351 |
+
# to the current tile and add the current tile to the result row
|
352 |
+
if i > 0:
|
353 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
354 |
+
if j > 0:
|
355 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
356 |
+
result_row.append(tile[:, :, :row_limit, :row_limit])
|
357 |
+
result_rows.append(torch.cat(result_row, dim=3))
|
358 |
+
|
359 |
+
dec = torch.cat(result_rows, dim=2)
|
360 |
+
if not return_dict:
|
361 |
+
return (dec,)
|
362 |
+
|
363 |
+
return dec
|
lyrasd_model/lyrasdxl_controlnet_txt2img_pipeline.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
5 |
+
|
6 |
+
import gc
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from glob import glob
|
10 |
+
|
11 |
+
import PIL
|
12 |
+
|
13 |
+
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
|
14 |
+
from diffusers.loaders import TextualInversionLoaderMixin
|
15 |
+
from diffusers.image_processor import VaeImageProcessor
|
16 |
+
from diffusers.models import AutoencoderKL
|
17 |
+
from diffusers.schedulers import (DPMSolverMultistepScheduler,
|
18 |
+
EulerAncestralDiscreteScheduler,
|
19 |
+
EulerDiscreteScheduler,
|
20 |
+
KarrasDiffusionSchedulers)
|
21 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
22 |
+
from diffusers.utils.torch_utils import randn_tensor
|
23 |
+
from diffusers.utils import logging
|
24 |
+
from PIL import Image
|
25 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
26 |
+
from diffusers.utils import PIL_INTERPOLATION
|
27 |
+
from .lyrasd_vae_model import LyraSdVaeModel
|
28 |
+
|
29 |
+
from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
|
30 |
+
from safetensors.torch import load_file
|
31 |
+
from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase
|
32 |
+
|
33 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
34 |
+
"""
|
35 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
36 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
37 |
+
"""
|
38 |
+
std_text = noise_pred_text.std(
|
39 |
+
dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
40 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
41 |
+
# rescale the results from guidance (fixes overexposure)
|
42 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
43 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
44 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + \
|
45 |
+
(1 - guidance_rescale) * noise_cfg
|
46 |
+
return noise_cfg
|
47 |
+
|
48 |
+
|
49 |
+
class LyraSdXLControlnetTxt2ImgPipeline(LyraSDXLPipelineBase, StableDiffusionXLPipeline):
|
50 |
+
device = torch.device("cpu")
|
51 |
+
dtype = torch.float32
|
52 |
+
|
53 |
+
def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.13025) -> None:
|
54 |
+
self.register_to_config(force_zeros_for_empty_prompt=True)
|
55 |
+
|
56 |
+
super().__init__(device, dtype, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)
|
57 |
+
|
58 |
+
|
59 |
+
def prepare_image(
|
60 |
+
self,
|
61 |
+
image,
|
62 |
+
width,
|
63 |
+
height,
|
64 |
+
batch_size,
|
65 |
+
num_images_per_prompt,
|
66 |
+
device,
|
67 |
+
dtype,
|
68 |
+
do_classifier_free_guidance=False,
|
69 |
+
guess_mode=False,
|
70 |
+
):
|
71 |
+
image = self.control_image_processor.preprocess(image, height, width)
|
72 |
+
image = image.permute(0, 2, 3, 1)
|
73 |
+
|
74 |
+
image = image.to(device=device, dtype=dtype)
|
75 |
+
# print(image.shape)
|
76 |
+
# print(image)
|
77 |
+
|
78 |
+
return image
|
79 |
+
|
80 |
+
@property
|
81 |
+
def _execution_device(self):
|
82 |
+
if not hasattr(self.unet, "_hf_hook"):
|
83 |
+
return self.device
|
84 |
+
for module in self.unet.modules():
|
85 |
+
if (
|
86 |
+
hasattr(module, "_hf_hook")
|
87 |
+
and hasattr(module._hf_hook, "execution_device")
|
88 |
+
and module._hf_hook.execution_device is not None
|
89 |
+
):
|
90 |
+
return torch.device(module._hf_hook.execution_device)
|
91 |
+
return self.device
|
92 |
+
|
93 |
+
def _get_aug_emb(self, add_embedding, time_ids, text_embeds, dtype):
|
94 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
95 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
96 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
97 |
+
add_embeds = add_embeds.to(dtype)
|
98 |
+
aug_emb = add_embedding(add_embeds)
|
99 |
+
return aug_emb
|
100 |
+
|
101 |
+
@torch.no_grad()
|
102 |
+
def __call__(
|
103 |
+
self,
|
104 |
+
prompt: Union[str, List[str]] = None,
|
105 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
106 |
+
height: Optional[int] = None,
|
107 |
+
width: Optional[int] = None,
|
108 |
+
num_inference_steps: int = 50,
|
109 |
+
denoising_end: Optional[float] = None,
|
110 |
+
guidance_scale: float = 5.0,
|
111 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
112 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
113 |
+
num_images_per_prompt: Optional[int] = 1,
|
114 |
+
controlnet_names: Optional[List[str]] = None,
|
115 |
+
controlnet_images: Optional[List[PIL.Image.Image]] = None,
|
116 |
+
controlnet_scale: Optional[List[float]] = None,
|
117 |
+
guess_mode=False,
|
118 |
+
eta: float = 0.0,
|
119 |
+
generator: Optional[Union[torch.Generator,
|
120 |
+
List[torch.Generator]]] = None,
|
121 |
+
latents: Optional[torch.FloatTensor] = None,
|
122 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
123 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
124 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
125 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
126 |
+
output_type: Optional[str] = "pil",
|
127 |
+
return_dict: bool = True,
|
128 |
+
callback: Optional[Callable[[
|
129 |
+
int, int, torch.FloatTensor], None]] = None,
|
130 |
+
callback_steps: int = 1,
|
131 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
132 |
+
guidance_rescale: float = 0.0,
|
133 |
+
original_size: Optional[Tuple[int, int]] = None,
|
134 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
135 |
+
target_size: Optional[Tuple[int, int]] = None,
|
136 |
+
):
|
137 |
+
|
138 |
+
# 0. Default height and width to unet
|
139 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
140 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
141 |
+
|
142 |
+
original_size = original_size or (height, width)
|
143 |
+
target_size = target_size or (height, width)
|
144 |
+
|
145 |
+
# 1. Check inputs. Raise error if not correct
|
146 |
+
self.check_inputs(
|
147 |
+
prompt,
|
148 |
+
prompt_2,
|
149 |
+
height,
|
150 |
+
width,
|
151 |
+
callback_steps,
|
152 |
+
negative_prompt,
|
153 |
+
negative_prompt_2,
|
154 |
+
prompt_embeds,
|
155 |
+
negative_prompt_embeds,
|
156 |
+
pooled_prompt_embeds,
|
157 |
+
negative_pooled_prompt_embeds,
|
158 |
+
)
|
159 |
+
|
160 |
+
# 2. Define call parameters
|
161 |
+
if prompt is not None and isinstance(prompt, str):
|
162 |
+
batch_size = 1
|
163 |
+
elif prompt is not None and isinstance(prompt, list):
|
164 |
+
batch_size = len(prompt)
|
165 |
+
else:
|
166 |
+
batch_size = prompt_embeds.shape[0]
|
167 |
+
|
168 |
+
device = self._execution_device
|
169 |
+
|
170 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
171 |
+
|
172 |
+
# 3. Encode input prompt
|
173 |
+
text_encoder_lora_scale = (
|
174 |
+
cross_attention_kwargs.get(
|
175 |
+
"scale", None) if cross_attention_kwargs is not None else None
|
176 |
+
)
|
177 |
+
(
|
178 |
+
prompt_embeds,
|
179 |
+
negative_prompt_embeds,
|
180 |
+
pooled_prompt_embeds,
|
181 |
+
negative_pooled_prompt_embeds,
|
182 |
+
) = self.encode_prompt(
|
183 |
+
prompt=prompt,
|
184 |
+
prompt_2=prompt_2,
|
185 |
+
device=device,
|
186 |
+
num_images_per_prompt=num_images_per_prompt,
|
187 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
188 |
+
negative_prompt=negative_prompt,
|
189 |
+
negative_prompt_2=negative_prompt_2,
|
190 |
+
prompt_embeds=prompt_embeds,
|
191 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
192 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
193 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
194 |
+
lora_scale=text_encoder_lora_scale,
|
195 |
+
)
|
196 |
+
|
197 |
+
control_images = []
|
198 |
+
|
199 |
+
for image_ in controlnet_images:
|
200 |
+
image_ = self.prepare_image(
|
201 |
+
image=image_,
|
202 |
+
width=width,
|
203 |
+
height=height,
|
204 |
+
batch_size=batch_size * num_images_per_prompt,
|
205 |
+
num_images_per_prompt=num_images_per_prompt,
|
206 |
+
device=device,
|
207 |
+
dtype=prompt_embeds.dtype,
|
208 |
+
do_classifier_free_guidance=do_classifier_free_guidance
|
209 |
+
)
|
210 |
+
|
211 |
+
control_images.append(image_)
|
212 |
+
|
213 |
+
control_scales = []
|
214 |
+
|
215 |
+
scales = [1.0, ] * 10
|
216 |
+
if guess_mode:
|
217 |
+
scales = torch.logspace(-1, 0, 10).tolist()
|
218 |
+
|
219 |
+
for scale in controlnet_scale:
|
220 |
+
scales_ = [d * scale for d in scales]
|
221 |
+
control_scales.append(scales_)
|
222 |
+
|
223 |
+
# 4. Prepare timesteps
|
224 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
225 |
+
|
226 |
+
timesteps = self.scheduler.timesteps
|
227 |
+
|
228 |
+
# 5. Prepare latent variables
|
229 |
+
num_channels_latents = self.unet_in_channels
|
230 |
+
latents = self.prepare_latents(
|
231 |
+
batch_size * num_images_per_prompt,
|
232 |
+
num_channels_latents,
|
233 |
+
height,
|
234 |
+
width,
|
235 |
+
prompt_embeds.dtype,
|
236 |
+
device,
|
237 |
+
generator,
|
238 |
+
latents,
|
239 |
+
)
|
240 |
+
|
241 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
242 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
243 |
+
|
244 |
+
# 7. Prepare added time ids & embeddings
|
245 |
+
add_text_embeds = pooled_prompt_embeds
|
246 |
+
add_time_ids = list(
|
247 |
+
original_size + crops_coords_top_left + target_size)
|
248 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
|
249 |
+
|
250 |
+
if do_classifier_free_guidance:
|
251 |
+
prompt_embeds = torch.cat(
|
252 |
+
[negative_prompt_embeds, prompt_embeds], dim=0)
|
253 |
+
add_text_embeds = torch.cat(
|
254 |
+
[negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
255 |
+
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
256 |
+
|
257 |
+
prompt_embeds = prompt_embeds.to(device)
|
258 |
+
add_text_embeds = add_text_embeds.to(device)
|
259 |
+
add_time_ids = add_time_ids.to(device).repeat(
|
260 |
+
batch_size * num_images_per_prompt, 1)
|
261 |
+
|
262 |
+
# 8. Denoising loop
|
263 |
+
num_warmup_steps = max(
|
264 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
265 |
+
|
266 |
+
# 7.1 Apply denoising_end
|
267 |
+
if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
|
268 |
+
discrete_timestep_cutoff = int(
|
269 |
+
round(
|
270 |
+
self.scheduler.config.num_train_timesteps
|
271 |
+
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
272 |
+
)
|
273 |
+
)
|
274 |
+
num_inference_steps = len(
|
275 |
+
list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
276 |
+
timesteps = timesteps[:num_inference_steps]
|
277 |
+
|
278 |
+
aug_emb = self._get_aug_emb(
|
279 |
+
self.add_embedding, add_time_ids, add_text_embeds, prompt_embeds.dtype)
|
280 |
+
|
281 |
+
controlnet_aug_embs = []
|
282 |
+
for controlnet_name in controlnet_names:
|
283 |
+
controlnet_aug_embs.append(self._get_aug_emb(self.controlnet_add_embedding[controlnet_name],
|
284 |
+
add_time_ids, add_text_embeds, prompt_embeds.dtype))
|
285 |
+
|
286 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
287 |
+
for i, t in enumerate(timesteps):
|
288 |
+
# expand the latents if we are doing classifier free guidance
|
289 |
+
latent_model_input = torch.cat(
|
290 |
+
[latents] * 2) if do_classifier_free_guidance else latents
|
291 |
+
|
292 |
+
latent_model_input = self.scheduler.scale_model_input(
|
293 |
+
latent_model_input, t)
|
294 |
+
latent_model_input = latent_model_input.permute(
|
295 |
+
0, 2, 3, 1).contiguous()
|
296 |
+
|
297 |
+
noise_pred = self.unet.forward(
|
298 |
+
latent_model_input, prompt_embeds, t, aug_emb,
|
299 |
+
controlnet_names, control_images, controlnet_aug_embs, control_scales, guess_mode).permute(0, 3, 1, 2)
|
300 |
+
|
301 |
+
# print(noise_pred)
|
302 |
+
|
303 |
+
# perform guidance
|
304 |
+
if do_classifier_free_guidance:
|
305 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
306 |
+
noise_pred = noise_pred_uncond + guidance_scale * \
|
307 |
+
(noise_pred_text - noise_pred_uncond)
|
308 |
+
|
309 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
310 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
311 |
+
noise_pred = rescale_noise_cfg(
|
312 |
+
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
313 |
+
|
314 |
+
# compute the previous noisy sample x_t -> x_t-1
|
315 |
+
latents = self.scheduler.step(
|
316 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
317 |
+
|
318 |
+
# call the callback, if provided
|
319 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
320 |
+
progress_bar.update()
|
321 |
+
if callback is not None and i % callback_steps == 0:
|
322 |
+
callback(i, t, latents)
|
323 |
+
|
324 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
325 |
+
# if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
326 |
+
# self.upcast_vae()
|
327 |
+
# latents = latents.to(
|
328 |
+
# next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
329 |
+
# # latents = latents.to(torch.float32)
|
330 |
+
# if output_type == "latent":
|
331 |
+
# return latents
|
332 |
+
|
333 |
+
# np.save(f"/workspace/latents.npy", latents.detach().cpu().numpy())
|
334 |
+
|
335 |
+
# image = self.vae.decode(
|
336 |
+
# latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
337 |
+
image = self.vae.decode(1 / self.vae.scaling_factor * latents)
|
338 |
+
|
339 |
+
image = self.image_processor.postprocess(
|
340 |
+
image, output_type=output_type)
|
341 |
+
|
342 |
+
# Offload last model to CPU
|
343 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
344 |
+
self.final_offload_hook.offload()
|
345 |
+
|
346 |
+
return image
|
lyrasd_model/lyrasdxl_pipeline_base.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
5 |
+
|
6 |
+
import gc
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from glob import glob
|
10 |
+
|
11 |
+
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
|
12 |
+
from diffusers.loaders import TextualInversionLoaderMixin
|
13 |
+
from diffusers.image_processor import VaeImageProcessor
|
14 |
+
from diffusers.models import AutoencoderKL
|
15 |
+
from diffusers.schedulers import (DPMSolverMultistepScheduler,
|
16 |
+
EulerAncestralDiscreteScheduler,
|
17 |
+
EulerDiscreteScheduler,
|
18 |
+
KarrasDiffusionSchedulers)
|
19 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
20 |
+
from diffusers.utils.torch_utils import randn_tensor
|
21 |
+
from diffusers.utils import logging
|
22 |
+
from PIL import Image
|
23 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
24 |
+
from .lyrasd_vae_model import LyraSdVaeModel
|
25 |
+
from .module.lyrasd_ip_adapter import LyraIPAdapter
|
26 |
+
from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
|
27 |
+
from safetensors.torch import load_file
|
28 |
+
|
29 |
+
|
30 |
+
class LyraSDXLPipelineBase(TextualInversionLoaderMixin):
|
31 |
+
def __init__(self, device=torch.device("cuda"), dtype=torch.float16, num_channels_unet=4, num_channels_latents=4, vae_scale_factor=8, vae_scaling_factor=0.13025) -> None:
|
32 |
+
self.device = device
|
33 |
+
self.dtype = dtype
|
34 |
+
|
35 |
+
self.num_channels_unet = num_channels_unet
|
36 |
+
self.num_channels_latents = num_channels_latents
|
37 |
+
self.vae_scale_factor = vae_scale_factor
|
38 |
+
self.vae_scaling_factor = vae_scaling_factor
|
39 |
+
|
40 |
+
self.unet_cache = {}
|
41 |
+
self.unet_in_channels = 4
|
42 |
+
|
43 |
+
self.controlnet_cache = {}
|
44 |
+
self.controlnet_add_embedding = {}
|
45 |
+
|
46 |
+
self.loaded_lora = {}
|
47 |
+
self.loaded_lora_strength = {}
|
48 |
+
|
49 |
+
self.scheduler = None
|
50 |
+
|
51 |
+
self.init_pipe()
|
52 |
+
|
53 |
+
def init_pipe(self):
|
54 |
+
self.vae = LyraSdVaeModel(
|
55 |
+
scale_factor=self.vae_scale_factor, scaling_factor=self.vae_scaling_factor, is_upcast=True)
|
56 |
+
|
57 |
+
self.unet = torch.classes.lyrasd.XLUnet2dConditionalModelOp(
|
58 |
+
"fp16",
|
59 |
+
self.num_channels_unet,
|
60 |
+
self.num_channels_latents)
|
61 |
+
|
62 |
+
self.default_sample_size = 128
|
63 |
+
self.addition_time_embed_dim = 256
|
64 |
+
flip_sin_to_cos, freq_shift = True, 0
|
65 |
+
self.projection_class_embeddings_input_dim, self.time_embed_dim = 2816, 1280
|
66 |
+
|
67 |
+
self.add_time_proj = Timesteps(
|
68 |
+
self.addition_time_embed_dim, flip_sin_to_cos, freq_shift).to(self.dtype).to(self.device)
|
69 |
+
|
70 |
+
self.add_embedding = TimestepEmbedding(
|
71 |
+
self.projection_class_embeddings_input_dim, self.time_embed_dim).to(self.dtype).to(self.device)
|
72 |
+
|
73 |
+
self.image_processor = VaeImageProcessor(
|
74 |
+
vae_scale_factor=self.vae_scale_factor)
|
75 |
+
|
76 |
+
self.mask_processor = VaeImageProcessor(
|
77 |
+
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
|
78 |
+
)
|
79 |
+
|
80 |
+
self.control_image_processor = VaeImageProcessor(
|
81 |
+
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
|
82 |
+
)
|
83 |
+
|
84 |
+
self.feature_extractor = CLIPImageProcessor()
|
85 |
+
|
86 |
+
def reload_pipe(self, model_path):
|
87 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(
|
88 |
+
model_path, subfolder="tokenizer")
|
89 |
+
self.text_encoder = CLIPTextModel.from_pretrained(
|
90 |
+
model_path, subfolder="text_encoder").to(self.dtype).to(self.device)
|
91 |
+
|
92 |
+
self.tokenizer_2 = CLIPTokenizer.from_pretrained(
|
93 |
+
model_path, subfolder="tokenizer_2")
|
94 |
+
self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
|
95 |
+
model_path, subfolder="text_encoder_2").to(self.dtype).to(self.device)
|
96 |
+
|
97 |
+
self.reload_unet_model_v2(model_path)
|
98 |
+
self.reload_vae_model_v2(model_path)
|
99 |
+
|
100 |
+
if not self.scheduler:
|
101 |
+
self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
|
102 |
+
model_path, subfolder="scheduler")
|
103 |
+
|
104 |
+
def load_embedding_weight(self, model, weight_path, unet_file_format="fp16"):
|
105 |
+
bin_list = glob(weight_path)
|
106 |
+
sate_dicts = model.state_dict()
|
107 |
+
dtype = np.float32 if unet_file_format == "fp32" else np.float16
|
108 |
+
for bin_file in bin_list:
|
109 |
+
weight = torch.from_numpy(np.fromfile(bin_file, dtype=dtype)).to(
|
110 |
+
self.dtype).to(self.device)
|
111 |
+
key = '.'.join(os.path.basename(bin_file).split('.')[1:-1])
|
112 |
+
weight = weight.reshape(sate_dicts[key].shape)
|
113 |
+
sate_dicts.update({key: weight})
|
114 |
+
model.load_state_dict(sate_dicts)
|
115 |
+
|
116 |
+
@property
|
117 |
+
def _execution_device(self):
|
118 |
+
if not hasattr(self.unet, "_hf_hook"):
|
119 |
+
return self.device
|
120 |
+
for module in self.unet.modules():
|
121 |
+
if (
|
122 |
+
hasattr(module, "_hf_hook")
|
123 |
+
and hasattr(module._hf_hook, "execution_device")
|
124 |
+
and module._hf_hook.execution_device is not None
|
125 |
+
):
|
126 |
+
return torch.device(module._hf_hook.execution_device)
|
127 |
+
return self.device
|
128 |
+
|
129 |
+
def reload_unet_model(self, unet_path, unet_file_format='fp32'):
|
130 |
+
if len(unet_path) > 0 and unet_path[-1] != "/":
|
131 |
+
unet_path = unet_path + "/"
|
132 |
+
self.unet.reload_unet_model(unet_path, unet_file_format)
|
133 |
+
self.load_embedding_weight(
|
134 |
+
self.add_embedding, f"{unet_path}add_embedding*", unet_file_format=unet_file_format)
|
135 |
+
|
136 |
+
def reload_vae_model(self, vae_path, vae_file_format='fp32'):
|
137 |
+
if len(vae_path) > 0 and vae_path[-1] != "/":
|
138 |
+
vae_path = vae_path + "/"
|
139 |
+
return self.vae.reload_vae_model(vae_path, vae_file_format)
|
140 |
+
|
141 |
+
def load_lora(self, lora_model_path, lora_name, lora_strength, lora_file_format='fp32'):
|
142 |
+
if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
|
143 |
+
lora_model_path = lora_model_path + "/"
|
144 |
+
lora = add_xltext_lora_layer(
|
145 |
+
self.text_encoder, self.text_encoder_2, lora_model_path, lora_strength, lora_file_format)
|
146 |
+
|
147 |
+
self.loaded_lora[lora_name] = lora
|
148 |
+
self.unet.load_lora(lora_model_path, lora_name,
|
149 |
+
lora_strength, lora_file_format)
|
150 |
+
|
151 |
+
def unload_lora(self, lora_name, clean_cache=False):
|
152 |
+
for layer_data in self.loaded_lora[lora_name]:
|
153 |
+
layer = layer_data['layer']
|
154 |
+
added_weight = layer_data['added_weight']
|
155 |
+
layer.weight.data -= added_weight
|
156 |
+
self.unet.unload_lora(lora_name, clean_cache)
|
157 |
+
del self.loaded_lora[lora_name]
|
158 |
+
gc.collect()
|
159 |
+
torch.cuda.empty_cache()
|
160 |
+
|
161 |
+
def load_lora_v2(self, lora_model_path, lora_name, lora_strength):
|
162 |
+
if lora_name in self.loaded_lora:
|
163 |
+
state_dict = self.loaded_lora[lora_name]
|
164 |
+
else:
|
165 |
+
state_dict = load_state_dict(lora_model_path)
|
166 |
+
self.loaded_lora[lora_name] = state_dict
|
167 |
+
self.loaded_lora_strength[lora_name] = lora_strength
|
168 |
+
add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
|
169 |
+
self.text_encoder_2, lora_strength)
|
170 |
+
|
171 |
+
def unload_lora_v2(self, lora_name, clean_cache=False):
|
172 |
+
state_dict = self.loaded_lora[lora_name]
|
173 |
+
lora_strength = self.loaded_lora_strength[lora_name]
|
174 |
+
add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
|
175 |
+
self.text_encoder_2, -1.0 * lora_strength)
|
176 |
+
del self.loaded_lora_strength[lora_name]
|
177 |
+
|
178 |
+
if clean_cache:
|
179 |
+
del self.loaded_lora[lora_name]
|
180 |
+
gc.collect()
|
181 |
+
torch.cuda.empty_cache()
|
182 |
+
|
183 |
+
def clean_lora_cache(self):
|
184 |
+
self.unet.clean_lora_cache()
|
185 |
+
|
186 |
+
def get_loaded_lora(self):
|
187 |
+
return self.unet.get_loaded_lora()
|
188 |
+
|
189 |
+
def _get_aug_emb(self, time_ids, text_embeds, dtype):
|
190 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
191 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
192 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
193 |
+
add_embeds = add_embeds.to(dtype)
|
194 |
+
aug_emb = self.add_embedding(add_embeds)
|
195 |
+
return aug_emb
|
196 |
+
|
197 |
+
def load_ip_adapter(self, dir_ip_adapter, ip_plus, image_encoder_path, num_ip_tokens, ip_projection_dim, dir_face_in=None, num_fp_tokens=1, fp_projection_dim=None, sdxl=True):
|
198 |
+
self.ip_adapter_helper = LyraIPAdapter(self, sdxl, "cuda", dir_ip_adapter, ip_plus, image_encoder_path,
|
199 |
+
num_ip_tokens, ip_projection_dim, dir_face_in, num_fp_tokens, fp_projection_dim)
|
200 |
+
|
201 |
+
def reload_unet_model_v2(self, model_path):
|
202 |
+
checkpoint_file = os.path.join(
|
203 |
+
model_path, "unet/diffusion_pytorch_model.bin")
|
204 |
+
if not os.path.exists(checkpoint_file):
|
205 |
+
checkpoint_file = os.path.join(
|
206 |
+
model_path, "unet/diffusion_pytorch_model.safetensors")
|
207 |
+
if checkpoint_file in self.unet_cache:
|
208 |
+
state_dict = self.unet_cache[checkpoint_file]
|
209 |
+
else:
|
210 |
+
if "safetensors" in checkpoint_file:
|
211 |
+
state_dict = load_file(checkpoint_file)
|
212 |
+
else:
|
213 |
+
state_dict = torch.load(checkpoint_file, map_location="cpu")
|
214 |
+
|
215 |
+
for key in state_dict:
|
216 |
+
if len(state_dict[key].shape) == 4:
|
217 |
+
# converted_unet_checkpoint[key] = converted_unet_checkpoint[key].to(torch.float16).to("cuda").permute(0,2,3,1).contiguous().cpu()
|
218 |
+
state_dict[key] = state_dict[key].to(
|
219 |
+
torch.float16).permute(0, 2, 3, 1).contiguous()
|
220 |
+
state_dict[key] = state_dict[key].to(torch.float16)
|
221 |
+
self.unet_cache[checkpoint_file] = state_dict
|
222 |
+
|
223 |
+
self.unet.reload_unet_model_from_cache(state_dict, "cpu")
|
224 |
+
self.load_embedding_weight_v2(self.add_embedding, state_dict)
|
225 |
+
|
226 |
+
def load_embedding_weight_v2(self, model, state_dict):
|
227 |
+
sub_state_dict = {}
|
228 |
+
for k in state_dict:
|
229 |
+
if k.startswith("add_embedding"):
|
230 |
+
v = state_dict[k]
|
231 |
+
sub_k = ".".join(k.split(".")[1:])
|
232 |
+
sub_state_dict[sub_k] = v
|
233 |
+
|
234 |
+
model.load_state_dict(sub_state_dict)
|
235 |
+
|
236 |
+
def reload_vae_model_v2(self, model_path):
|
237 |
+
self.vae.reload_vae_model_v2(model_path)
|
238 |
+
|
239 |
+
def load_controlnet_model_v2(self, model_name, controlnet_path):
|
240 |
+
checkpoint_file = os.path.join(
|
241 |
+
controlnet_path, "diffusion_pytorch_model.bin")
|
242 |
+
if not os.path.exists(checkpoint_file):
|
243 |
+
checkpoint_file = os.path.join(
|
244 |
+
controlnet_path, "diffusion_pytorch_model.safetensors")
|
245 |
+
if checkpoint_file in self.controlnet_cache:
|
246 |
+
state_dict = self.controlnet_cache[checkpoint_file]
|
247 |
+
else:
|
248 |
+
if "safetensors" in checkpoint_file:
|
249 |
+
state_dict = load_file(checkpoint_file)
|
250 |
+
else:
|
251 |
+
state_dict = torch.load(checkpoint_file, map_location="cpu")
|
252 |
+
|
253 |
+
for key in state_dict:
|
254 |
+
if len(state_dict[key].shape) == 4:
|
255 |
+
# converted_unet_checkpoint[key] = converted_unet_checkpoint[key].to(torch.float16).to("cuda").permute(0,2,3,1).contiguous().cpu()
|
256 |
+
state_dict[key] = state_dict[key].to(
|
257 |
+
torch.float16).permute(0, 2, 3, 1).contiguous()
|
258 |
+
state_dict[key] = state_dict[key].to(torch.float16)
|
259 |
+
self.controlnet_cache[checkpoint_file] = state_dict
|
260 |
+
|
261 |
+
self.unet.load_controlnet_model_from_state_dict(
|
262 |
+
model_name, state_dict, "cpu")
|
263 |
+
|
264 |
+
add_embedding = TimestepEmbedding(
|
265 |
+
self.projection_class_embeddings_input_dim, self.time_embed_dim).to(self.dtype).to(self.device)
|
266 |
+
|
267 |
+
self.load_embedding_weight_v2(add_embedding, state_dict)
|
268 |
+
self.controlnet_add_embedding[model_name] = add_embedding
|
269 |
+
|
270 |
+
def unload_controlnet_model(self, model_name):
|
271 |
+
self.unet.unload_controlnet_model(model_name, True)
|
272 |
+
del self.controlnet_add_embedding[model_name]
|
273 |
+
|
274 |
+
def get_loaded_controlnet(self):
|
275 |
+
return self.unet.get_loaded_controlnet()
|
lyrasd_model/lyrasdxl_txt2img_inpaint_pipeline.py
ADDED
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
5 |
+
|
6 |
+
import gc
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from glob import glob
|
10 |
+
|
11 |
+
from diffusers import StableDiffusionXLInpaintPipeline, UNet2DConditionModel
|
12 |
+
from diffusers.loaders import TextualInversionLoaderMixin
|
13 |
+
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
|
14 |
+
from diffusers.models import AutoencoderKL
|
15 |
+
from diffusers.schedulers import (DPMSolverMultistepScheduler,
|
16 |
+
EulerAncestralDiscreteScheduler,
|
17 |
+
EulerDiscreteScheduler,
|
18 |
+
KarrasDiffusionSchedulers)
|
19 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
20 |
+
from diffusers.utils.torch_utils import randn_tensor
|
21 |
+
from diffusers.utils import logging
|
22 |
+
from PIL import Image
|
23 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
24 |
+
from .lyrasd_vae_model import LyraSdVaeModel
|
25 |
+
from .module.lyrasd_ip_adapter import LyraIPAdapter
|
26 |
+
from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
|
27 |
+
from safetensors.torch import load_file
|
28 |
+
|
29 |
+
from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase
|
30 |
+
|
31 |
+
|
32 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
33 |
+
"""
|
34 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
35 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
36 |
+
"""
|
37 |
+
std_text = noise_pred_text.std(
|
38 |
+
dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
39 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
40 |
+
# rescale the results from guidance (fixes overexposure)
|
41 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
42 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
43 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + \
|
44 |
+
(1 - guidance_rescale) * noise_cfg
|
45 |
+
return noise_cfg
|
46 |
+
|
47 |
+
def retrieve_latents(
|
48 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
49 |
+
):
|
50 |
+
if sample_mode == "sample":
|
51 |
+
return encoder_output.sample(generator)
|
52 |
+
elif sample_mode == "argmax":
|
53 |
+
return encoder_output.mode()
|
54 |
+
else:
|
55 |
+
return encoder_output
|
56 |
+
|
57 |
+
|
58 |
+
def retrieve_timesteps(
|
59 |
+
scheduler,
|
60 |
+
num_inference_steps: Optional[int] = None,
|
61 |
+
device: Optional[Union[str, torch.device]] = None,
|
62 |
+
timesteps: Optional[List[int]] = None,
|
63 |
+
**kwargs,
|
64 |
+
):
|
65 |
+
"""
|
66 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
67 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
scheduler (`SchedulerMixin`):
|
71 |
+
The scheduler to get timesteps from.
|
72 |
+
num_inference_steps (`int`):
|
73 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
74 |
+
`timesteps` must be `None`.
|
75 |
+
device (`str` or `torch.device`, *optional*):
|
76 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
77 |
+
timesteps (`List[int]`, *optional*):
|
78 |
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
79 |
+
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
80 |
+
must be `None`.
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
84 |
+
second element is the number of inference steps.
|
85 |
+
"""
|
86 |
+
if timesteps is not None:
|
87 |
+
accepts_timesteps = "timesteps" in set(
|
88 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys())
|
89 |
+
if not accepts_timesteps:
|
90 |
+
raise ValueError(
|
91 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
92 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
93 |
+
)
|
94 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
95 |
+
timesteps = scheduler.timesteps
|
96 |
+
num_inference_steps = len(timesteps)
|
97 |
+
else:
|
98 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
99 |
+
timesteps = scheduler.timesteps
|
100 |
+
return timesteps, num_inference_steps
|
101 |
+
|
102 |
+
|
103 |
+
class LyraSdXLTxt2ImgInpaintPipeline(LyraSDXLPipelineBase, StableDiffusionXLInpaintPipeline):
|
104 |
+
device = torch.device("cpu")
|
105 |
+
dtype = torch.float32
|
106 |
+
|
107 |
+
def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.13025, num_channels_unet=9, num_channels_latents=4, requires_aesthetics_score: bool = False,
|
108 |
+
force_zeros_for_empty_prompt: bool = True) -> None:
|
109 |
+
self.register_to_config(
|
110 |
+
force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
111 |
+
self.register_to_config(
|
112 |
+
requires_aesthetics_score=requires_aesthetics_score)
|
113 |
+
|
114 |
+
super().__init__(device, dtype, num_channels_unet=num_channels_unet, num_channels_latents=num_channels_latents, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)
|
115 |
+
|
116 |
+
|
117 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
118 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
119 |
+
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(
                image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(
            num_images_per_prompt, dim=0)

        uncond_image_embeds = torch.zeros_like(image_embeds)
        return image_embeds, uncond_image_embeds

    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        dtype = image.dtype
        # if self.vae.config.force_upcast:
        #     image = image.float()
        #     self.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        image_latents = image_latents.to(dtype)
        image_latents = self.vae.scaling_factor * image_latents

        return image_latents

    def _get_add_time_ids(
        self,
        original_size,
        crops_coords_top_left,
        target_size,
        aesthetic_score,
        negative_aesthetic_score,
        negative_original_size,
        negative_crops_coords_top_left,
        negative_target_size,
        dtype,
        text_encoder_projection_dim=None,
    ):
        if self.config.requires_aesthetics_score:
            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
            add_neg_time_ids = list(
                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
            )
        else:
            add_time_ids = list(original_size + crops_coords_top_left + target_size)
            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)

        passed_add_embed_dim = (
            self.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.add_embedding.linear_1.in_features

        if (
            expected_add_embed_dim > passed_add_embed_dim
            and (expected_add_embed_dim - passed_add_embed_dim) == self.addition_time_embed_dim
        ):
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
            )
        elif (
            expected_add_embed_dim < passed_add_embed_dim
            and (passed_add_embed_dim - expected_add_embed_dim) == self.addition_time_embed_dim
        ):
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
            )
        elif expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

        return add_time_ids, add_neg_time_ids

    def load_ip_adapter(self, dir_ip_adapter, ip_plus, image_encoder_path, num_ip_tokens, ip_projection_dim, dir_face_in=None, num_fp_tokens=1, fp_projection_dim=None, sdxl=True):
        self.ip_adapter_helper = LyraIPAdapter(self, sdxl, "cuda", dir_ip_adapter, ip_plus, image_encoder_path,
                                               num_ip_tokens, ip_projection_dim, dir_face_in, num_fp_tokens, fp_projection_dim)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        masked_image_latents: torch.FloatTensor = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        strength: float = 0.9999,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        denoising_start: Optional[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        aesthetic_score: float = 6.0,
        negative_aesthetic_score: float = 2.5,
        clip_skip: Optional[int] = None,
        extra_tensor_dict: Optional[Dict[str, torch.FloatTensor]] = {},
        param_scale_dict: Optional[Dict[str, int]] = {},
        **kwargs
    ):

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            strength,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get(
                "scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip
        )

        def denoising_value_valid(dnv):
            return isinstance(self.denoising_end, float) and 0 < dnv < 1

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps,
            strength,
            device,
            denoising_start=self.denoising_start if denoising_value_valid else None,
        )

        latent_timestep = timesteps[:1].repeat(
            batch_size * num_images_per_prompt)
        is_strength_max = strength == 1.0

        # 5. Prepare latent variables

        init_image = self.image_processor.preprocess(
            image, height=height, width=width)
        init_image = init_image.to(dtype=torch.float32)

        mask = self.mask_processor.preprocess(
            mask_image, height=height, width=width)

        if masked_image_latents is not None:
            masked_image = masked_image_latents
        elif init_image.shape[1] == 4:
            # if images are in latent space, we can't mask it
            masked_image = None
        else:
            masked_image = init_image * (mask < 0.5)

        add_noise = True if self.denoising_start is None else False

        return_image_latents = self.num_channels_unet == 4

        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            self.num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            add_noise=add_noise,
            return_noise=True,
            return_image_latents=return_image_latents,
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        if negative_original_size is None:
            negative_original_size = original_size
        if negative_target_size is None:
            negative_target_size = target_size

        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            aesthetic_score,
            negative_aesthetic_score,
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        add_time_ids = add_time_ids.repeat(
            batch_size * num_images_per_prompt, 1)

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat(
                [negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_neg_time_ids = add_neg_time_ids.repeat(
                batch_size * num_images_per_prompt, 1)
            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device)

        # 8. Denoising loop
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 7.1 Apply denoising_end
        if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(
                list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        aug_emb = self._get_aug_emb(
            add_time_ids, add_text_embeds, prompt_embeds.dtype)

        extra_tensor_dict2 = {}
        for name in extra_tensor_dict:
            if name in ["fp_hidden_states", "ip_hidden_states"]:
                v1, v2 = extra_tensor_dict[name][0], extra_tensor_dict[name][1]
                extra_tensor_dict2[name] = torch.cat(
                    [v1.repeat(num_images_per_prompt, 1, 1), v2.repeat(num_images_per_prompt, 1, 1)])
            else:
                extra_tensor_dict2[name] = extra_tensor_dict[name]

        # np.save("/workspace/prompt_embeds.npy", prompt_embeds.detach().cpu().numpy())
        # prompt_embeds = torch.from_numpy(np.load("/workspace/gt_prompt_embeds.npy")).cuda()
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                if self.num_channels_unet == 9:
                    latent_model_input = torch.cat(
                        [latent_model_input, mask, masked_image_latents], dim=1)

                latent_model_input = latent_model_input.permute(
                    0, 2, 3, 1).contiguous()

                noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, aug_emb, None, None,
                                               None, None, None, extra_tensor_dict2, param_scale_dict).permute(0, 3, 1, 2).contiguous()

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if self.num_channels_unet == 4:
                    init_latents_proper = image_latents
                    if do_classifier_free_guidance:
                        init_mask, _ = mask.chunk(2)
                    else:
                        init_mask = mask

                    if i < len(timesteps) - 1:
                        noise_timestep = timesteps[i + 1]
                        init_latents_proper = self.scheduler.add_noise(
                            init_latents_proper, noise, torch.tensor(
                                [noise_timestep])
                        )

                    latents = (1 - init_mask) * \
                        init_latents_proper + init_mask * latents

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            return latents

        image = self.vae.decode(1 / self.vae.scaling_factor * latents)
        image = self.image_processor.postprocess(
            image, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        return image
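For orientation, here is a minimal usage sketch of the inpainting `__call__` above. The pipeline class name, model directory, and `reload_pipe` setup are assumptions patterned on the demo scripts later in this commit; only the call arguments come from the signature above.

```python
# Hypothetical usage sketch of the SDXL inpaint pipeline above.
# The class name and paths are assumptions; the keyword arguments mirror __call__'s signature.
import torch
from PIL import Image
from lyrasd_model import LyraSdXLTxt2ImgInpaintPipeline  # assumed export name

torch.classes.load_library("./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so")

pipe = LyraSdXLTxt2ImgInpaintPipeline()
pipe.reload_pipe("./models/helloworldSDXL20Fp16")  # assumed model directory

init_image = Image.open("input.png").convert("RGB")
mask_image = Image.open("mask.png").convert("L")   # white = repaint, black = keep

images = pipe(
    prompt="a cat, masterpiece, best quality",
    image=init_image,
    mask_image=mask_image,
    height=1024,
    width=1024,
    strength=0.8,              # fraction of the noise schedule applied to the masked region
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=torch.Generator().manual_seed(123),
)
images[0].save("outputs/res_inpaint_0.png")
```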
lyrasd_model/lyrasdxl_txt2img_pipeline.py
ADDED
@@ -0,0 +1,267 @@
import inspect
import os
import time
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

import gc
import torch
import numpy as np
from glob import glob

from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL
from diffusers.schedulers import (DPMSolverMultistepScheduler,
                                  EulerAncestralDiscreteScheduler,
                                  EulerDiscreteScheduler,
                                  KarrasDiffusionSchedulers)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import logging
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from .lyrasd_vae_model import LyraSdVaeModel
from .module.lyrasd_ip_adapter import LyraIPAdapter
from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
from safetensors.torch import load_file
from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + \
        (1 - guidance_rescale) * noise_cfg
    return noise_cfg


class LyraSdXLTxt2ImgPipeline(LyraSDXLPipelineBase, StableDiffusionXLPipeline):
    device = torch.device("cpu")
    dtype = torch.float32

    def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.13025) -> None:
        self.register_to_config(force_zeros_for_empty_prompt=True)

        super().__init__(device, dtype, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[
            int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        extra_tensor_dict: Optional[Dict[str, torch.FloatTensor]] = {},
        param_scale_dict: Optional[Dict[str, int]] = {},
        clip_skip: Optional[int] = None
    ):

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get(
                "scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet_in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = list(
            original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat(
                [negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(
            batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 7.1 Apply denoising_end
        if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(
                list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        aug_emb = self._get_aug_emb(
            add_time_ids, add_text_embeds, prompt_embeds.dtype)

        extra_tensor_dict2 = {}
        for name in extra_tensor_dict:
            if name in ["fp_hidden_states", "ip_hidden_states"]:
                v1, v2 = extra_tensor_dict[name][0], extra_tensor_dict[name][1]
                extra_tensor_dict2[name] = torch.cat(
                    [v1.repeat(num_images_per_prompt, 1, 1), v2.repeat(num_images_per_prompt, 1, 1)])
            else:
                extra_tensor_dict2[name] = extra_tensor_dict[name]

        # np.save("/workspace/prompt_embeds.npy", prompt_embeds.detach().cpu().numpy())
        # prompt_embeds = torch.from_numpy(np.load("/workspace/gt_prompt_embeds.npy")).cuda()
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)
                latent_model_input = latent_model_input.permute(
                    0, 2, 3, 1).contiguous()

                noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, aug_emb, None, None,
                                               None, None, None, extra_tensor_dict2, param_scale_dict).permute(0, 3, 1, 2).contiguous()

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            return latents

        image = self.vae.decode(1 / self.vae.scaling_factor * latents)
        image = self.image_processor.postprocess(
            image, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        return image
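As a quick illustration of the `rescale_noise_cfg` helper defined at the top of the file above, the standalone sketch below applies it to random tensors; the shapes and the 0.7 rescale factor are arbitrary, and the import path simply follows this file's location in the package.

```python
# Illustration only: exercise rescale_noise_cfg from the file above with dummy tensors.
import torch
from lyrasd_model.lyrasdxl_txt2img_pipeline import rescale_noise_cfg

noise_pred_uncond = torch.randn(2, 4, 128, 128)
noise_pred_text = torch.randn(2, 4, 128, 128)
guidance_scale = 7.5

# classifier-free guidance combination, as in the denoising loop above
noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# pull the per-sample std of the guided prediction back toward the text branch
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
print(rescaled.shape)  # torch.Size([2, 4, 128, 128])
```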
lyrasd_model/{lyrasd_lib/placeholder.txt → module/__init__.py}
RENAMED
File without changes
lyrasd_model/module/lyra_tool.py
ADDED
@@ -0,0 +1,5 @@
import yaml

def load_yaml(cfg_path):
    with open(cfg_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
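A one-line usage sketch for the helper above; the config file name is a placeholder.

```python
# Load a YAML config with the helper above ("config.yaml" is a placeholder path).
from lyrasd_model.module.lyra_tool import load_yaml

cfg = load_yaml("config.yaml")
print(cfg)
```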
lyrasd_model/module/lyrasd_ip_adapter.py
ADDED
@@ -0,0 +1,289 @@
import os, sys
from typing import List

import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from diffusers.models.embeddings import ImageProjection
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Union
from copy import deepcopy
import time
sys.path.append(os.path.dirname(__file__))
from resampler import Resampler
from diffusers import DiffusionPipeline
import numpy as np
# sys.path.append(os.environ['LYRASD_WORKDIR'] + "/tests/utils")
from .tools import get_mem_use

class ImageProjModel(torch.nn.Module):
    """Projection Model"""
    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


class LyraIPAdapter:
    def __init__(
        self,
        sd_pipe,
        sdxl,
        device,
        ip_ckpt=None,
        ip_plus=False,
        image_encoder_path=None,
        num_ip_tokens=4,
        ip_projection_dim=None,
        fp_ckpt=None,
        num_fp_tokens=1,
        fp_projection_dim=None,
    ):
        self.pipe = sd_pipe
        self.device = device
        self.fp_ckpt = fp_ckpt
        self.ip_ckpt = ip_ckpt
        self.num_fp_tokens = num_fp_tokens
        self.num_ip_tokens = num_ip_tokens
        self.fp_projection_dim = fp_projection_dim
        self.ip_projection_dim = ip_projection_dim
        self.sdxl = sdxl
        self.ip_plus = ip_plus
        self.cross_attention_dim = 2048
        # self.pipe = sd_pipe.to(self.device)
        # self.set_ip_adapter()

        if image_encoder_path:
            self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(self.device, dtype=torch.float16)
            self.clip_image_processor = CLIPImageProcessor()
            self.projection_dim = self.image_encoder.config.projection_dim

        # image proj model
        if self.ip_ckpt:
            if self.ip_plus:
                proj_heads = 20 if self.sdxl else 12
                self.image_proj_model = self.init_proj_plus(proj_heads, self.num_ip_tokens)
            else:
                self.image_proj_model = self.init_proj(self.ip_projection_dim, self.num_ip_tokens)

        # face proj model
        if self.fp_ckpt:
            self.face_proj_model = self.init_proj(self.fp_projection_dim, self.num_fp_tokens)

        self.load_ip_adapter()

    def init_proj_diffuser(self, state_dict):
        # diffusers-style loading path
        clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
        cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4

        image_proj_model = ImageProjection(
            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
        ).to(dtype=self.dtype, device=self.device)
        return image_proj_model

    # init_proj / init_proj_plus are implemented in FaceIn
    def init_proj(self, projection_dim, num_tokens):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.cross_attention_dim,
            clip_embeddings_dim=projection_dim,
            clip_extra_context_tokens=num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model


    def init_proj_plus(self, heads, num_tokens):
        image_proj_model = Resampler(
            dim=1280,
            depth=4,
            dim_head=64,
            heads=heads,
            num_queries=num_tokens,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def load_ip_adapter(self):
        unet = self.pipe.unet

        def parse_ckpt_path(ckpt):
            ll = ckpt.split("/")
            weight_name = ll[-1]
            subfolder = ll[-2]
            pretrained_path = "/".join(ll[:-2])
            return pretrained_path, subfolder, weight_name

        if self.ip_ckpt:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
            self.image_proj_model.load_state_dict(state_dict["image_proj"])
            pretrained_path, subfolder, weight_name = parse_ckpt_path(self.ip_ckpt)
            dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
            unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")

        if self.fp_ckpt:
            state_dict = torch.load(self.fp_ckpt, map_location="cpu")
            self.face_proj_model.load_state_dict(state_dict["face_proj"])
            pretrained_path, subfolder, weight_name = parse_ckpt_path(self.fp_ckpt)
            dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
            unet.load_facein(dir_ipadapter, "fp16")

    @torch.inference_mode()
    def get_image_embeds(self, image=None, face_emb=None):
        image_prompt_embeds, uncond_image_prompt_embeds = None, None

        if image is not None:
            if not isinstance(image, list):
                image = [image]
            clip_image = self.clip_image_processor(images=image, return_tensors="pt").pixel_values
            clip_image = clip_image.to(self.device, dtype=torch.float16)
            if self.ip_plus:
                clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
                uncond_clip_image_embeds = self.image_encoder(
                    torch.zeros_like(clip_image), output_hidden_states=True
                ).hidden_states[-2]
            else:
                clip_image_embeds = self.image_encoder(clip_image).image_embeds
                uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds)
            clip_image_prompt_embeds = self.image_proj_model(clip_image_embeds)
            uncond_clip_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
            image_prompt_embeds = clip_image_prompt_embeds
            uncond_image_prompt_embeds = uncond_clip_image_prompt_embeds

        if face_emb is not None:
            face_embeds = face_emb.to(self.device, dtype=torch.float16)
            face_prompt_embeds = self.face_proj_model(face_embeds)
            uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
            if image_prompt_embeds is None:
                image_prompt_embeds = face_prompt_embeds
                uncond_image_prompt_embeds = uncond_face_prompt_embeds
            else:
                image_prompt_embeds = torch.cat([face_prompt_embeds, image_prompt_embeds], axis=1)
                uncond_image_prompt_embeds = torch.cat([uncond_face_prompt_embeds, uncond_image_prompt_embeds], dim=1)

        return image_prompt_embeds, uncond_image_prompt_embeds

    @torch.inference_mode()
    def get_image_embeds_lyrasd(self, image=None, ip_image_embeds=None, face_emb=None, batch_size=1, ip_scale=1.0, fp_scale=1.0, do_classifier_free_guidance=True):
        dict_tensor = {}

        if self.ip_ckpt and ip_scale > 0:
            if ip_image_embeds is not None:
                dict_tensor["ip_hidden_states"] = ip_image_embeds
            elif image is not None:
                if not isinstance(image, list):
                    image = [image]
                clip_image = self.clip_image_processor(images=image, return_tensors="pt").pixel_values
                clip_image = clip_image.to(self.device, dtype=torch.float16)
                if self.ip_plus:
                    clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
                    uncond_clip_image_embeds = self.image_encoder(
                        torch.zeros_like(clip_image), output_hidden_states=True
                    ).hidden_states[-2]
                else:
                    clip_image_embeds = self.image_encoder(clip_image).image_embeds
                    uncond_clip_image_embeds = torch.zeros_like(clip_image_embeds)

                if do_classifier_free_guidance:
                    clip_image_embeds = torch.cat([uncond_clip_image_embeds, clip_image_embeds])
                ip_image_embeds = self.image_proj_model(clip_image_embeds)
                dict_tensor["ip_hidden_states"] = ip_image_embeds

        if face_emb is not None and self.fp_ckpt and ip_scale > 0:
            face_embeds = face_emb.to(self.device, dtype=torch.float16)
            face_prompt_embeds = self.face_proj_model(face_embeds)
            uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
            if do_classifier_free_guidance:
                fp_image_embeds = torch.cat([uncond_face_prompt_embeds, face_prompt_embeds])
            else:
                fp_image_embeds = face_prompt_embeds
            dict_tensor["fp_hidden_states"] = fp_image_embeds
        return dict_tensor


if __name__ == "__main__":
    sys.path.append("/data/home/kiokaxiao/repos/LyraSD/python/lyrasd")
    from lyrasd_model import LyraSdXLTxt2ImgPipeline

    model_path = "/data/SharedModels/SD/checkpoints/stable-diffusion-xl-base-1.0/"
    # model_path = "/cfs-datasets/projects/VirtualIdol/models/base_model/sdxl/xxmix9realisticsdxlV1"
    lib_path = os.environ.get("LIBLYRASD_SO")

    dir_ip_adapter = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin"
    dir_facein = "/cfs-datasets/projects/VirtualIdol/models/FaceIn/v1/FaceIn_sdxl.bin"
    image_encoder_path = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/models/image_encoder"

    pipeline = LyraSdXLTxt2ImgPipeline(model_path, lib_path)
    pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16, 1024, dir_facein, 1, 512)
    # pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16, 1024, "", 1, 512)

    face_emb = np.load("/data/home/kiokaxiao/repos/VidolImageDraw/girl.npy")
    face_emb = torch.Tensor(face_emb.reshape([1, -1]))
    ip_image = Image.open("/data/home/kiokaxiao/repos/VidolImageDraw/images/input_image.png").convert('RGB')

    generator = torch.Generator("cuda").manual_seed(123)
    batches = [2]
    sizes = [[512, 512], [768, 768], [1024, 1024]]
    # sizes = [[832, 640]]
    # sizes = [[1024, 1024]]
    running_cnt = 1
    do_bench = False

    ip_ratio = 1
    facein_ratio = 0.6
    extra_tensor_dict = {}
    extra_tensor_dict = pipeline.ip_adapter_helper.get_image_embeds_lyrasd(ip_image, None, face_emb, batches[0], ip_ratio, facein_ratio)
    param_scale_dict = {"facein_ratio": facein_ratio, "ip_ratio": ip_ratio}
    draw_cfg = {'width': 640,
                'num_inference_steps': 30,
                'height': 832,
                'negative_prompt': '(worst quality, low quality, 3d, 2d, cartoons, sketch), tooth, open mouth',
                'guidance_scale': 7,
                'prompt': 'xxmixgirl, masterpiece, best quality, 1girl, solo, looking at viewer, simple background, hair ornament, black eyes, portrait',
                'output_type': 'pil',
                'extra_tensor_dict': extra_tensor_dict,
                "param_scale_dict": param_scale_dict}


    def warmup(draw_cfg):
        draw_cfg_wm = deepcopy(draw_cfg)
        draw_cfg_wm['num_inference_steps'] = 1
        pipeline(**draw_cfg_wm, generator=generator)

    if not do_bench:
        images = pipeline(**draw_cfg, generator=generator)
    else:
        for batch in batches:
            for height, width in sizes:
                draw_cfg['width'] = width
                draw_cfg['height'] = height
                draw_cfg['num_images_per_prompt'] = batch
                draw_cfg["num_inference_steps"] = 20
                warmup(draw_cfg)
                time_uses = []
                for x in range(running_cnt):
                    start = time.perf_counter()
                    draw_cfg['num_images_per_prompt'] = batch
                    generator = torch.Generator("cuda").manual_seed(123)
                    print("draw_cfg: ", draw_cfg.keys())
                    print("draw_cfg: ", draw_cfg)

                    images = pipeline(**draw_cfg, generator=generator)
                    time_use = time.perf_counter() - start
                    time_uses.append(time_use)
                print("bench", batch, width, sum(time_uses)/running_cnt, get_mem_use())

    print(type(images))
    images[0].save("t.png")
lyrasd_model/module/resampler.py
ADDED
@@ -0,0 +1,121 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math

import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)


    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)
        print("layers: ", len(self.layers))
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
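For a quick shape check of the `Resampler` above, the sketch below uses the SDXL IP-Adapter-plus settings from `lyrasd_ip_adapter.py` in this commit (20 heads, 16 query tokens, 2048 output dim); the 257×1280 input is a stand-in for the penultimate hidden states of a CLIP ViT-H image encoder and is an assumption for illustration.

```python
# Shape check: resample 257 CLIP hidden states into 16 cross-attention tokens.
import torch
from lyrasd_model.module.resampler import Resampler

resampler = Resampler(dim=1280, depth=4, dim_head=64, heads=20,
                      num_queries=16, embedding_dim=1280, output_dim=2048, ff_mult=4)
image_features = torch.randn(1, 257, 1280)  # stand-in for CLIP penultimate hidden states
tokens = resampler(image_features)
print(tokens.shape)  # torch.Size([1, 16, 2048])
```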
lyrasd_model/module/tools.py
ADDED
@@ -0,0 +1,148 @@
import torch
import numpy as np
import os, sys
import time

class LyraChecker:
    def __init__(self, dir_data, tol):
        self.dir_data = dir_data
        self.tol = tol

    def cmp(self, fpath1, fpath2="", tol=0):
        tolbk = self.tol
        if tol != 0:
            self.tol = tol
        if fpath2 == "":
            fpath2 = fpath1
            fpath1 += "_1"
            fpath2 += "_2"
        v1 = self.get_npy(fpath1)  # np.load(os.path.join(self.dir_data, fpath1))
        v2 = self.get_npy(fpath2)  # np.load(os.path.join(self.dir_data, fpath2))
        name = fpath1
        if ".npy" in fpath1:
            name = ".".join(os.path.basename(fpath1).split(".")[:-1])
        self._cmp_inner(v1, v2, name)
        self.tol = tolbk

    def _cmp_inner(self, v1, v2, name):
        print(v1.shape, v2.shape)
        if v1.shape != v2.shape:
            if v1.shape[1] == v2.shape[1]:
                v2 = v2.reshape([v2.shape[0], v2.shape[1], -1])
            else:
                v2 = torch.tensor(v2).permute(0, 3, 1, 2).numpy()
        print(v1.shape, v2.shape)
        self._check_data(name, v1, v2)
        print(np.size(v1))

    def _check_data(self, stage, x_out, x_gt):
        print(f"========== {stage} =============")
        print(x_out.shape, x_gt.shape)
        if np.allclose(x_gt, x_out, atol=self.tol):
            print(f"[OK] At {stage}, tol: {self.tol}")
        else:
            diff_cnt = np.count_nonzero(np.abs(x_gt - x_out) > self.tol)
            print(f"[FAIL]At {stage}, not aligned. tol: {self.tol}")
            print("  [INFO]Max diff: ", np.max(np.abs(x_gt - x_out)))
            print("  [INFO]Diff count: ", diff_cnt, ", ratio: ", round(diff_cnt/np.size(x_out), 2))
        print(f">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")


    def cmp_query(self, fpath1, fpath2):
        v1 = np.load(os.path.join(self.dir_data, fpath1))
        vk = np.load(os.path.join(self.dir_data, fpath1).replace("query", "key"))
        vv = np.load(os.path.join(self.dir_data, fpath1).replace("query", "value"))

        v2 = np.load(os.path.join(self.dir_data, fpath2))
        # print(v1.shape, v2.shape)
        q2 = v2[:, :, 0, :, :].transpose([0, 2, 1, 3])
        # print(v1.shape, q2.shape)
        self._check_data("query", v1, q2)
        # print(vk.shape, v2.shape)
        k2 = v2[:, :, 1, :, :].transpose([0, 2, 1, 3])
        self._check_data("key", vk, k2)
        vv2 = v2[:, :, 2, :, :].transpose([0, 2, 1, 3])
        # print(vv.shape, vv2.shape)
        self._check_data("value", vv, vv2)

    def _get_data_fpath(self, fname):
        fpath = os.path.join(self.dir_data, fname)
        if not fpath.endswith(".npy"):
            fpath += ".npy"
        return fpath

    def get_npy(self, fname):
        fpath = self._get_data_fpath(fname)
        return np.load(fpath)




class MkDataHelper:
    def __init__(self, data_dir="/data/home/kiokaxiao/data"):
        self.data_dir = data_dir

    def mkdata(self, subdir, name, shape, dtype=torch.float16):
        outdir = os.path.join(self.data_dir, subdir)
        os.makedirs(outdir, exist_ok=True)
        fpath = os.path.join(outdir, name+".npy")
        data = torch.randn(shape, dtype=torch.float16)
        np.save(fpath, data.to(dtype).numpy())
        return data

    def gen_out_with_func(self, func, inputs):
        output = func(inputs)
        return output

    def savedata(self, subdir, name, data):
        outdir = os.path.join(self.data_dir, subdir)
        os.makedirs(outdir, exist_ok=True)
        fpath = os.path.join(outdir, name+".npy")
        np.save(fpath, data.cpu().numpy())


class TorchSaver:
    def __init__(self, data_dir):
        self.data_dir = data_dir
        os.makedirs(self.data_dir, exist_ok=True)
        self.is_save = True

    def save_v(self, name, v):
        if not self.is_save:
            return
        fpath = os.path.join(self.data_dir, name+"_1.npy")
        np.save(fpath, v.detach().cpu().numpy())

    def save_v2(self, name, v):
        if not self.is_save:
            return
        fpath = os.path.join(self.data_dir, name+"_1.npy")
        np.save(fpath, v.detach().cpu().numpy())

def timer_annoc(funct):
    def inner(*args, **kwargs):
        start = time.perf_counter()
        res = funct(*args, **kwargs)
        torch.cuda.synchronize()
        end = time.perf_counter()
        print("torch cost: ", end-start)
        return res
    return inner

def get_mem_use():
    f = os.popen("nvidia-smi | grep MiB")
    line = f.read().strip()
    while "  " in line:
        line = line.replace("  ", " ")
    memuse = line.split(" ")[8]
    return memuse

if __name__ == "__main__":
    dir_data = sys.argv[1]
    fname_v1 = sys.argv[2]
    fname_v2 = sys.argv[3]
    tol = 0.01
    if len(sys.argv) > 4:
        tol = float(sys.argv[4])
    checker = LyraChecker(dir_data, tol)
    checker.cmp(fname_v1, fname_v2)
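A short usage sketch for the debugging helpers above; the dump directory and file names are placeholders.

```python
# Compare two dumped tensors and report GPU memory (paths/names are placeholders).
from lyrasd_model.module.tools import LyraChecker, get_mem_use

checker = LyraChecker("/path/to/npy_dumps", tol=0.01)
checker.cmp("prompt_embeds_1.npy", "prompt_embeds_2.npy")  # prints an [OK]/[FAIL] report
print("GPU memory in use (MiB):", get_mem_use())           # parsed from `nvidia-smi`
```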
models/README.md
CHANGED
@@ -2,11 +2,20 @@
### This is the place where you should download the checkpoints, and unzip them

```bash
wget -O lyrasd_rev_animated.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/lyrasd_rev_animated.tar.gz"

wget -O sd-controlnet-canny.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/sd-controlnet-canny.tar.gz"

wget -O xiaorenshu.safetensors "https://civitai.com/api/download/models/25661"

wget -O helloworldSDXL20Fp16.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/helloworldSDXL20Fp16.tar.gz"

wget -O controlnet-canny-sdxl-1.0.tar.gz "https://chuangxin-research-1258344705.cos.ap-guangzhou.myqcloud.com/share/files/lyrasd/controlnet-canny-sdxl-1.0.tar.gz"

wget -O dissolve_sdxl.safetensors "https://civitai.com/api/download/models/277389?type=Model&format=SafeTensor"

tar -xvf lyrasd_rev_animated.tar.gz
tar -xvf sd-controlnet-canny.tar.gz
tar -xvf helloworldSDXL20Fp16.tar.gz
tar -xvf controlnet-canny-sdxl-1.0.tar.gz
```
outputs/res_controlnet_img2img_0.png
CHANGED
outputs/{res_controlnet_sdxl_txt2img.png → res_controlnet_sdxl_txt2img_0.png}
RENAMED
File without changes
outputs/res_controlnet_txt2img_0.png
CHANGED
outputs/res_img2img_0.png
CHANGED
outputs/res_txt2img_lora_0.png
CHANGED
outputs/{res_sdxl_txt2img_lora_0.png → res_txt2img_xl_lora_0.png}
RENAMED
File without changes
txt2img_demo.py
CHANGED
@@ -10,22 +10,25 @@ from lyrasd_model import LyraSdTxt2ImgPipeline
# 4. the scheduler config

# LyraSD's C++ compiled dynamic library, containing the C++ CUDA compute internals
lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
model_path = "./models/rev-animated"
lora_path = "./models/xiaorenshu.safetensors"

torch.classes.load_library(lib_path)

# Build the Txt2Img pipeline
model = LyraSdTxt2ImgPipeline()
model.reload_pipe(model_path)

# load lora
# arguments: lora path, name, lora strength, lora precision
model.load_lora_v2(lora_path, "xiaorenshu", 0.4)

# Prepare the inputs and hyperparameters
prompt = "a cat, cute, cartoon, concise, traditional, chinese painting, Tang and Song Dynasties, masterpiece, 4k, 8k, UHD, best quality"
negative_prompt = "(((horrible))), (((scary))), (((naked))), (((large breasts))), high saturation, colorful, human:2, body:2, low quality, bad quality, lowres, out of frame, duplicate, watermark, signature, text, frames, cut, cropped, malformed limbs, extra limbs, (((missing arms))), (((missing legs)))"
height, width = 512, 512
steps = 20
guidance_scale = 7
generator = torch.Generator().manual_seed(123)
num_images = 1

@@ -33,12 +36,12 @@ num_images = 1

start = time.perf_counter()
# Run inference
images = model(prompt, height, width, steps,
               guidance_scale, negative_prompt, num_images,
               generator=generator)
print("image gen cost: ", time.perf_counter() - start)
# Save the generated images
for i, image in enumerate(images):
    image.save(f"outputs/res_txt2img_lora_{i}.png")

# unload lora; arguments are the lora name and whether to clear the lora cache
model.unload_lora_v2("xiaorenshu", True)
txt2img_sdxl_demo.py
ADDED
@@ -0,0 +1,55 @@
import torch
from lyrasd_model import LyraSdXLTxt2ImgPipeline
import time
import GPUtil
import os
from glob import glob
import random

# Path where the model files are stored; it should contain:
# 1. the clip model
# 2. the converted, optimized unet model, placed in its unet_bins folder
# 3. the vae model
# 4. the scheduler config

# LyraSD's C++ compiled dynamic library, containing the C++ CUDA compute internals
lib_path = "./lyrasd_model/lyrasd_lib/libth_lyrasd_cu12_sm80.so"
model_path = "./models/helloworldSDXL20Fp16"
lora_path = "./models/dissolve_sdxl.safetensors"
torch.classes.load_library(lib_path)

# Build the Txt2Img pipeline
model = LyraSdXLTxt2ImgPipeline()

model.reload_pipe(model_path)

# load lora
# lora model path, name, lora strength
model.load_lora_v2(lora_path, "dissolve_sdxl", 0.4)

# Prepare the inputs and hyperparameters
prompt = "a cat, ral-dissolve"
negative_prompt = "nswf, watermark"
height, width = 1024, 1024
steps = 20
guidance_scale = 7.5
generator = torch.Generator().manual_seed(8788800)

start = time.perf_counter()
# Run inference
images = model(prompt,
               height=height,
               width=width,
               num_inference_steps=steps,
               num_images_per_prompt=1,
               guidance_scale=guidance_scale,
               negative_prompt=negative_prompt,
               generator=generator
               )
print("image gen cost: ", time.perf_counter() - start)
# Save the generated images
for i, image in enumerate(images):
    image.save(f"outputs/res_txt2img_xl_lora_{i}.png")

# unload lora; arguments are the lora name and whether to clear the lora cache
model.unload_lora_v2("dissolve_sdxl", True)