aka7774 commited on
Commit
a366424
1 Parent(s): fa559b1

Use "cpu" in merge, save model

Browse files
Files changed (1) hide show
  1. evosdxl_jp_v1.py +9 -4
evosdxl_jp_v1.py CHANGED
@@ -15,7 +15,7 @@ from diffusers.loaders import LoraLoaderMixin
15
  SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
16
  JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
17
  L_REPO = "ByteDance/SDXL-Lightning"
18
-
19
 
20
  def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
21
  file_extension = os.path.basename(checkpoint_file).split(".")[-1]
@@ -104,7 +104,7 @@ def split_conv_attn(weights):
104
  return {"conv": conv_tensors, "attn": attn_tensors}
105
 
106
 
107
- def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
108
  sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
109
  dpo_weights = split_conv_attn(
110
  load_from_pretrained(
@@ -172,9 +172,12 @@ def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
172
  [0.023119324530758375, 0.04924981616469831, 0.9276308593045434],
173
  )
174
  new_weights = {**new_conv, **new_attn}
175
- unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
176
- unet.load_state_dict({**new_conv, **new_attn})
177
 
 
 
 
 
178
  text_encoder = CLIPTextModelWithProjection.from_pretrained(
179
  JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16"
180
  )
@@ -199,6 +202,8 @@ def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
199
 
200
 
201
  if __name__ == "__main__":
 
 
202
  pipe: StableDiffusionXLPipeline = load_evosdxl_jp()
203
  images = pipe("犬", num_inference_steps=4, guidance_scale=0).images
204
  images[0].save("out.png")
 
15
  SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
16
  JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
17
  L_REPO = "ByteDance/SDXL-Lightning"
18
+ MERGED_FILE = "evosdxl_jp_v1.safetensors"
19
 
20
  def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
21
  file_extension = os.path.basename(checkpoint_file).split(".")[-1]
 
104
  return {"conv": conv_tensors, "attn": attn_tensors}
105
 
106
 
107
+ def merge_evosdxl_jp(device="cpu") -> StableDiffusionXLPipeline:
108
  sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
109
  dpo_weights = split_conv_attn(
110
  load_from_pretrained(
 
172
  [0.023119324530758375, 0.04924981616469831, 0.9276308593045434],
173
  )
174
  new_weights = {**new_conv, **new_attn}
175
+ safetensors.torch.save_file(new_weights, MERGED_FILE)
 
176
 
177
+ def load_evosdxl_jp(device="cuda"):
178
+ unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
179
+ unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
180
+ unet.load_state_dict(safetensors.torch.load_file(MERGED_FILE))
181
  text_encoder = CLIPTextModelWithProjection.from_pretrained(
182
  JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16"
183
  )
 
202
 
203
 
204
  if __name__ == "__main__":
205
+ if not os.path.exists(MERGED_FILE):
206
+ merge_evosdxl_jp()
207
  pipe: StableDiffusionXLPipeline = load_evosdxl_jp()
208
  images = pipe("犬", num_inference_steps=4, guidance_scale=0).images
209
  images[0].save("out.png")