Zhiminli commited on
Commit
59b294b
1 Parent(s): 33a2f69

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +58 -0
README.md CHANGED
@@ -174,4 +174,62 @@ python sample_t2i.py --infer-mode fa --prompt "青花瓷风格,一只猫在追
174
  python sample_t2i.py --prompt "青花瓷风格,一只猫在追蝴蝶" --image-size 1280 768 --load-key ema --lora_ckpt ./ckpts/t2i/lora/porcelain
175
  ```
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  More example prompts can be found in [example_prompts.txt](example_prompts.txt)
 
174
  python sample_t2i.py --prompt "青花瓷风格,一只猫在追蝴蝶" --image-size 1280 768 --load-key ema --lora_ckpt ./ckpts/t2i/lora/porcelain
175
  ```
176
 
177
+
178
+ Regarding how to use the LoRA weights we trained in diffusion, we provide the following script. To ensure compatibility with the diffuser, some modifications are made, which means that LoRA cannot be directly loaded.
179
+
180
+ ```python
181
+ import torch
182
+ from diffusers import HunyuanDiTPipeline
183
+
184
+ num_layers = 40
185
+ def load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale):
186
+ for i in range(num_layers):
187
+ Wqkv = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_A.weight"])
188
+ q, k, v = torch.chunk(Wqkv, 3, dim=0)
189
+ transformer_state_dict[f"blocks.{i}.attn1.to_q.weight"] += lora_scale * q
190
+ transformer_state_dict[f"blocks.{i}.attn1.to_k.weight"] += lora_scale * k
191
+ transformer_state_dict[f"blocks.{i}.attn1.to_v.weight"] += lora_scale * v
192
+
193
+ out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_A.weight"])
194
+ transformer_state_dict[f"blocks.{i}.attn1.to_out.0.weight"] += lora_scale * out_proj
195
+
196
+ q_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_A.weight"])
197
+ transformer_state_dict[f"blocks.{i}.attn2.to_q.weight"] += lora_scale * q_proj
198
+
199
+ kv_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_A.weight"])
200
+ k, v = torch.chunk(kv_proj, 2, dim=0)
201
+ transformer_state_dict[f"blocks.{i}.attn2.to_k.weight"] += lora_scale * k
202
+ transformer_state_dict[f"blocks.{i}.attn2.to_v.weight"] += lora_scale * v
203
+
204
+ out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_A.weight"])
205
+ transformer_state_dict[f"blocks.{i}.attn2.to_out.0.weight"] += lora_scale * out_proj
206
+
207
+ q_proj = torch.matmul(lora_state_dict["pooler.q_proj.lora_B.weight"], lora_state_dict["pooler.q_proj.lora_A.weight"])
208
+ transformer_state_dict["time_extra_emb.pooler.q_proj.weight"] += lora_scale * q_proj
209
+
210
+ return transformer_state_dict
211
+
212
+ pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", torch_dtype=torch.float16)
213
+ pipe.to("cuda")
214
+
215
+ from safetensors import safe_open
216
+
217
+ lora_state_dict = {}
218
+ with safe_open("./ckpts/t2i/lora/jade/adapter_model.safetensors", framework="pt", device=0) as f:
219
+ for k in f.keys():
220
+ lora_state_dict[k[17:]] = f.get_tensor(k) # remove 'basemodel.model'
221
+
222
+ transformer_state_dict = pipe.transformer.state_dict()
223
+ transformer_state_dict = load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale=1.0)
224
+ pipe.transformer.load_state_dict(transformer_state_dict)
225
+
226
+ prompt = "玉石绘画风格,一只猫在追蝴蝶"
227
+ image = pipe(
228
+ prompt,
229
+ num_inference_steps=100,
230
+ guidance_scale=6.0,
231
+ ).images[0]
232
+ image.save('img.png')
233
+ ```
234
+
235
  More example prompts can be found in [example_prompts.txt](example_prompts.txt)