mkshing yuki-imajuku committed on
Commit
082e32a
·
verified ·
1 Parent(s): f603dd5

Create evo_nishikie_v1.py (#1)

Browse files

- Create evo_nishikie_v1.py (ca25b1a59f17fac66a9ca3481fee12d97e6b0594)


Co-authored-by: Yuki Imajuku <yuki-imajuku@users.noreply.huggingface.co>

Files changed (1) hide show
  1. evo_nishikie_v1.py +206 -0
evo_nishikie_v1.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ from io import BytesIO
3
+ import os
4
+ from typing import Dict, List, Union
5
+
6
+ from PIL import Image
7
+ from controlnet_aux import CannyDetector
8
+ from diffusers import (
9
+ ControlNetModel,
10
+ StableDiffusionXLControlNetPipeline,
11
+ UNet2DConditionModel,
12
+ )
13
+ from huggingface_hub import hf_hub_download
14
+ import requests
15
+ import safetensors
16
+ import torch
17
+ from tqdm import tqdm
18
+ from transformers import AutoTokenizer, CLIPTextModelWithProjection
19
+
20
# Hugging Face Hub repository ids of the base models (fine-tuned from SDXL-1.0).
# Their UNet checkpoints are downloaded and merged in `load_evo_nishikie`.
SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"

# Evo-Ukiyoe: repository the LoRA weights are loaded from.
UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"

# Evo-Nishikie: repository the ControlNet weights are loaded from.
NISHIKIE_REPO = "SakanaAI/Evo-Nishikie-v1"
31
+
32
+
33
def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
    """Load a checkpoint file into memory as a state dict.

    Files whose extension is ``.safetensors`` are read with
    ``safetensors.torch.load_file``; every other file is read with
    ``torch.load``.

    Args:
        checkpoint_file: Path to the checkpoint on disk.
        device: Device the loaded tensors are placed on.

    Returns:
        Whatever the selected loader returns for the file (for a
        safetensors file, a mapping from parameter name to tensor).
    """
    extension = os.path.splitext(os.fspath(checkpoint_file))[1]
    if extension == ".safetensors":
        # `import safetensors` alone does not import the `torch` submodule, so
        # import the loader explicitly instead of relying on another library
        # having imported `safetensors.torch` as a side effect.
        from safetensors.torch import load_file

        return load_file(checkpoint_file, device=device)
    # NOTE(review): torch.load unpickles the file; only use it on checkpoints
    # from trusted sources (consider `weights_only=True` where supported).
    return torch.load(checkpoint_file, map_location=device)
39
+
40
+
41
def load_from_pretrained(
    repo_id,
    filename="diffusion_pytorch_model.fp16.safetensors",
    subfolder="unet",
    device="cuda",
) -> Dict[str, torch.Tensor]:
    """Download one checkpoint file from the Hub and load it as a state dict.

    Args:
        repo_id: Hub repository to download from.
        filename: Name of the checkpoint file inside ``subfolder``.
        subfolder: Folder of the repository that contains ``filename``.
        device: Device passed on to ``load_state_dict``.

    Returns:
        The state dict produced by ``load_state_dict`` for the downloaded file.
    """
    # hf_hub_download returns the local path of the fetched file.
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
    )
    return load_state_dict(local_path, device=device)
55
+
56
+
57
def reshape_weight_task_tensors(task_tensors, weights):
    """
    Give `weights` the same number of dimensions as `task_tensors` by appending size-1 dimensions.

    Args:
        task_tensors (`torch.Tensor`): Tensor whose dimensionality `weights` should match.
        weights (`torch.Tensor`): Tensor to be viewed with extra trailing dimensions.

    Returns:
        `torch.Tensor`: A view of `weights` with trailing singleton dimensions added.
    """
    extra_dims = task_tensors.dim() - weights.dim()
    trailing_ones = (1,) * extra_dims
    return weights.view(weights.shape + trailing_ones)
71
+
72
+
73
def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
    """
    Merge the task tensors as a weighted sum.

    Args:
        task_tensors (`List[torch.Tensor]`): The tensors to merge; they are stacked along a new first dimension.
        weights (`torch.Tensor`): One coefficient per task tensor.

    Returns:
        `torch.Tensor`: The sum over tasks of each tensor multiplied by its coefficient.
    """
    stacked = torch.stack(task_tensors, dim=0)
    # Add trailing singleton dimensions so the coefficients broadcast over each task's tensor.
    coefficients = reshape_weight_task_tensors(stacked, weights)
    return (stacked * coefficients).sum(dim=0)
90
+
91
+
92
def merge_models(task_tensors, weights):
    """Merge several state dicts into one by a per-key weighted sum.

    The key set is taken from the first state dict. Every entry is popped
    from all input dicts while merging, so the inputs are emptied by this
    call.

    Args:
        task_tensors: Sequence of state dicts (parameter name -> tensor), one
            per model. All of them must contain the keys of the first one.
        weights: Sequence of coefficients, one per state dict.

    Returns:
        A new dict mapping each key to the weighted sum of the corresponding
        tensors.

    Raises:
        ValueError: If the number of weights differs from the number of
            state dicts.
    """
    # Without this check a single weight would broadcast silently over all
    # models and produce a wrongly scaled merge.
    if len(weights) != len(task_tensors):
        raise ValueError(
            f"Expected one weight per state dict, got {len(weights)} weights "
            f"for {len(task_tensors)} state dicts."
        )
    keys = list(task_tensors[0].keys())
    # Create the coefficients on the same device as the tensors being merged.
    weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
    state_dict = {}
    for key in tqdm(keys, desc="Merging"):
        # pop() removes the entry from each source dict as it is consumed.
        per_model = [sd.pop(key) for sd in task_tensors]
        state_dict[key] = linear(task_tensors=per_model, weights=weights)
    return state_dict
104
+
105
+
106
def split_conv_attn(weights):
    """Partition a state dict into attention-projection entries and the rest.

    A key goes to the ``"attn"`` group when it contains one of ``to_k``,
    ``to_q``, ``to_v`` or ``to_out.0``; every other key goes to ``"conv"``.
    Entries are popped from ``weights``, which is left empty.

    Args:
        weights: Mapping from parameter name to value; consumed by this call.

    Returns:
        ``{"conv": {...}, "attn": {...}}`` holding the moved entries.
    """
    attn_markers = ("to_k", "to_q", "to_v", "to_out.0")
    groups = {"conv": {}, "attn": {}}
    # Iterate over a snapshot of the keys because entries are removed as we go.
    for name in tuple(weights):
        is_attention = any(marker in name for marker in attn_markers)
        target = "attn" if is_attention else "conv"
        groups[target][name] = weights.pop(name)
    return groups
115
+
116
+
117
def load_evo_nishikie(device="cuda") -> StableDiffusionXLControlNetPipeline:
    """Build the Evo-Nishikie ControlNet pipeline.

    Steps, in order:
      1. Load the UNet weights of the four base models and split each into
         "conv" and "attn" groups.
      2. Merge each group across the models with fixed coefficients and load
         the result into a UNet built from the SDXL UNet config.
      3. Load the text encoder and tokenizer from the JSDXL repository and
         the ControlNet from the Evo-Nishikie repository.
      4. Assemble the SDXL ControlNet pipeline, move it to ``device`` in
         float16, then load and fuse the Evo-Ukiyoe LoRA weights.

    Args:
        device: Device string used for loading the weights and for the final
            pipeline.

    Returns:
        The assembled pipeline with the LoRA weights fused.
    """
    # Load base models: (repository, extra arguments for load_from_pretrained).
    # Only the DPO repository uses a non-default checkpoint file name.
    checkpoint_specs = (
        (SDXL_REPO, {}),
        (DPO_REPO, {"filename": "diffusion_pytorch_model.safetensors"}),
        (JN_REPO, {}),
        (JSDXL_REPO, {}),
    )
    tensors = [
        split_conv_attn(load_from_pretrained(repo, device=device, **extra))
        for repo, extra in checkpoint_specs
    ]

    # Merge base models. Coefficients follow the order of `checkpoint_specs`
    # (SDXL, DPO, Juggernaut, JSDXL).
    conv_coefficients = [
        0.15928833971605916,
        0.1032449268871776,
        0.6503217149752791,
        0.08714501842148402,
    ]
    attn_coefficients = [
        0.1877279276437178,
        0.20014114603909822,
        0.3922685507065275,
        0.2198623756106564,
    ]
    merged_conv = merge_models(
        [group["conv"] for group in tensors], conv_coefficients
    )
    merged_attn = merge_models(
        [group["attn"] for group in tensors], attn_coefficients
    )
    # merge_models popped every entry from the source dicts; collect garbage
    # (and the CUDA cache, if applicable) before allocating the UNet.
    gc.collect()
    if "cuda" in device:
        torch.cuda.empty_cache()

    unet = UNet2DConditionModel.from_config(
        UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
    ).to(device=device)
    unet.load_state_dict({**merged_conv, **merged_attn})

    # Load other modules
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        JSDXL_REPO,
        subfolder="text_encoder",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        JSDXL_REPO,
        subfolder="tokenizer",
        use_fast=False,
    )

    # Load Evo-Nishikie weights
    # NOTE(review): `device` is forwarded to from_pretrained as-is; confirm
    # that ControlNetModel.from_pretrained actually honors this keyword.
    controlnet = ControlNetModel.from_pretrained(
        NISHIKIE_REPO,
        torch_dtype=torch.float16,
        device=device,
    )

    # Load pipeline
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        SDXL_REPO,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        controlnet=controlnet,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe = pipe.to(device, dtype=torch.float16)

    # Load Evo-Ukiyoe weights
    pipe.load_lora_weights(UKIYOE_REPO)
    pipe.fuse_lora(lora_scale=1.0)
    return pipe
185
+
186
+
187
if __name__ == "__main__":
    # Example run: download a reference photo, extract its Canny edges and
    # generate one image conditioned on those edges.
    url = "https://sakana.ai/assets/nedo-grant/nedo_grant.jpeg"
    # Use a timeout so a stalled connection cannot hang the script, and fail
    # loudly on HTTP errors instead of handing an error page to PIL.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    original_image = Image.open(BytesIO(response.content)).resize(
        (1024, 1024), Image.Resampling.LANCZOS
    )
    canny_detector = CannyDetector()
    canny_image = canny_detector(original_image, image_resolution=1024)
    pipe: StableDiffusionXLControlNetPipeline = load_evo_nishikie()
    images = pipe(
        prompt="銀杏が色づく。草木が生えた地面と青空の富士山。最高品質の輻の浮世絵。",
        negative_prompt="暗い。",
        image=canny_image,
        guidance_scale=8.0,
        controlnet_conditioning_scale=0.6,
        num_inference_steps=50,
        # Fixed seed so repeated runs start from the same generator state.
        generator=torch.Generator().manual_seed(0),
        num_images_per_prompt=1,
        output_type="pil",
    ).images
    images[0].save("out.png")