Commit: Use "cpu" in merge, save model

Changed file: evosdxl_jp_v1.py (+9 lines, -4 lines)
@@ -15,7 +15,7 @@ from diffusers.loaders import LoraLoaderMixin
|
|
15 |
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
|
16 |
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
|
17 |
L_REPO = "ByteDance/SDXL-Lightning"
|
18 |
-
|
19 |
|
20 |
def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
|
21 |
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
@@ -104,7 +104,7 @@ def split_conv_attn(weights):
|
|
104 |
return {"conv": conv_tensors, "attn": attn_tensors}
|
105 |
|
106 |
|
107 |
-
def
|
108 |
sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
|
109 |
dpo_weights = split_conv_attn(
|
110 |
load_from_pretrained(
|
@@ -172,9 +172,12 @@ def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
|
|
172 |
[0.023119324530758375, 0.04924981616469831, 0.9276308593045434],
|
173 |
)
|
174 |
new_weights = {**new_conv, **new_attn}
|
175 |
-
|
176 |
-
unet.load_state_dict({**new_conv, **new_attn})
|
177 |
|
|
|
|
|
|
|
|
|
178 |
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
179 |
JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16"
|
180 |
)
|
@@ -199,6 +202,8 @@ def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
|
|
199 |
|
200 |
|
201 |
if __name__ == "__main__":
|
|
|
|
|
202 |
pipe: StableDiffusionXLPipeline = load_evosdxl_jp()
|
203 |
images = pipe("犬", num_inference_steps=4, guidance_scale=0).images
|
204 |
images[0].save("out.png")
|
|
|
# Hugging Face Hub repository ids for the models merged into EvoSDXL-JP.
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"  # base SDXL checkpoint
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"  # Japanese SDXL (text-encoder source)
L_REPO = "ByteDance/SDXL-Lightning"  # few-step SDXL-Lightning weights

# Local cache path for the merged UNet state dict (safetensors format);
# merge_evosdxl_jp writes it and load_evosdxl_jp reads it back.
MERGED_FILE = "evosdxl_jp_v1.safetensors"
|
19 |
|
20 |
def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
|
21 |
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
|
|
104 |
return {"conv": conv_tensors, "attn": attn_tensors}
|
105 |
|
106 |
|
107 |
+
def merge_evosdxl_jp(device="cpu") -> StableDiffusionXLPipeline:
|
108 |
sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
|
109 |
dpo_weights = split_conv_attn(
|
110 |
load_from_pretrained(
|
|
|
172 |
[0.023119324530758375, 0.04924981616469831, 0.9276308593045434],
|
173 |
)
|
174 |
new_weights = {**new_conv, **new_attn}
|
175 |
+
safetensors.torch.save_file(new_weights, MERGED_FILE)
|
|
|
176 |
|
177 |
+
def load_evosdxl_jp(device="cuda"):
|
178 |
+
unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
|
179 |
+
unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
|
180 |
+
unet.load_state_dict(safetensors.torch.load_file(MERGED_FILE))
|
181 |
text_encoder = CLIPTextModelWithProjection.from_pretrained(
|
182 |
JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16"
|
183 |
)
|
|
|
202 |
|
203 |
|
204 |
def _main() -> None:
    """Demo entry point: merge weights once, load the pipeline, render a sample.

    Side effects: may create MERGED_FILE on first run; always writes out.png.
    """
    if not os.path.exists(MERGED_FILE):
        # First run only — produce the merged safetensors checkpoint on disk
        # so subsequent runs can skip the (slow) merge.
        merge_evosdxl_jp()
    pipe: StableDiffusionXLPipeline = load_evosdxl_jp()
    # NOTE(review): 4 steps with guidance_scale=0 looks like SDXL-Lightning-style
    # few-step sampling — confirm against the merged scheduler configuration.
    result = pipe("犬", num_inference_steps=4, guidance_scale=0).images
    result[0].save("out.png")


if __name__ == "__main__":
    _main()
|