Safetensors
aredden committed
Commit a71da07 · 1 Parent(s): af20799

Add lora loading

Files changed (3)
  1. README.md +19 -0
  2. flux_pipeline.py +25 -7
  3. lora_loading.py +443 -0
README.md CHANGED
@@ -54,6 +54,25 @@ Note:
 
 **note:** prequantized flow models will only work with the quantization levels they were created with, e.g. if you create a prequantized flow model with `quantize_modulation` set to false, it will only work with `quantize_modulation` set to false; the same applies to `quantize_flow_embedder_layers`.
 
+ ### Updates 08/25/24
+
+ - Added LoRA loading functionality to FluxPipeline. Simple example:
+
+ ```python
+ from flux_pipeline import FluxPipeline
+
+ config_path = "path/to/config/file.json"
+ config_overrides = {
+     #...
+ }
+
+ lora_path = "path/to/lora/file.safetensors"
+
+ pipeline = FluxPipeline.load_pipeline_from_config_path(config_path, **config_overrides)
+
+ pipeline.load_lora(lora_path, scale=1.0)
+ ```
+
 ## Installation
 
 This repo _requires_ at least pytorch with cuda=12.4 and an ADA gpu with fp8 support, otherwise `torch._scaled_mm` will throw a CUDA error saying it's not supported. To install with conda/mamba:
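The `config_overrides` placeholder above is left elided in the commit. As a rough illustration of the prequantized-model note, a sketch of what overrides might look like; the key names `quantize_modulation` and `quantize_flow_embedder_layers` are taken from that note, but treating them as accepted override keys is an assumption, not confirmed API:

```python
from flux_pipeline import FluxPipeline

config_path = "path/to/config/file.json"

# Hypothetical override keys, drawn from the prequantized-model note above.
config_overrides = {
    "quantize_modulation": False,            # must match the value the prequantized model was built with
    "quantize_flow_embedder_layers": False,  # the same constraint applies to this flag
}

pipeline = FluxPipeline.load_pipeline_from_config_path(config_path, **config_overrides)
```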
flux_pipeline.py CHANGED
@@ -1,17 +1,18 @@
 import io
 import math
 import random
+ import warnings
 from typing import TYPE_CHECKING, Callable, List
- from PIL import Image
+
 import numpy as np
- import warnings
+ from PIL import Image
 
 warnings.filterwarnings("ignore", category=UserWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 import torch
-
 from einops import rearrange
+
 from flux_emphasis import get_weighted_text_embeddings_flux
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -19,16 +20,20 @@ torch.backends.cudnn.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.benchmark_limit = 20
 torch.set_float32_matmul_precision("high")
+ from pybase64 import standard_b64decode
 from torch._dynamo import config
 from torch._inductor import config as ind_config
- from pybase64 import standard_b64decode
 
 config.cache_size_limit = 10000000000
 ind_config.shape_padding = True
+ import platform
+
 from loguru import logger
- from image_encoder import ImageEncoder
 from torchvision.transforms import functional as TF
 from tqdm import tqdm
+
+ import lora_loading
+ from image_encoder import ImageEncoder
 from util import (
     ModelSpec,
     ModelVersion,
@@ -37,7 +42,6 @@ from util import (
     load_config_from_path,
     load_models_from_config,
 )
- import platform
 
 if platform.system() == "Windows":
     MAX_RAND = 2**16 - 1
@@ -46,9 +50,9 @@ else:
 
 
 if TYPE_CHECKING:
+     from modules.autoencoder import AutoEncoder
     from modules.conditioner import HFEmbedder
     from modules.flux_model import Flux
-     from modules.autoencoder import AutoEncoder
 
 
 class FluxPipeline:
@@ -144,6 +148,20 @@ class FluxPipeline:
         random.seed(seed)
         return cuda_generator, seed
 
+     def load_lora(self, lora_path: str, scale: float):
+         """
+         Loads a LoRA checkpoint into the Flux flow transformer.
+
+         Currently supports LoRA checkpoints whose keys follow either the diffusers
+         layout, which usually starts with transformer.[...], or the layout whose keys start with lora_unet_[...].
+
+         Args:
+             lora_path (str): Path to the LoRA checkpoint.
+             scale (float): Scaling factor for the LoRA weights.
+
+         """
+         self.model = lora_loading.apply_lora_to_model(self.model, lora_path, scale)
+
     @torch.inference_mode()
     def compile(self):
         """
lora_loading.py ADDED
@@ -0,0 +1,443 @@
+ import torch
+ from cublas_ops import CublasLinear
+ from loguru import logger
+ from safetensors.torch import load_file
+ from tqdm import tqdm
+
+ from float8_quantize import F8Linear
+ from modules.flux_model import Flux
+
+
+ def swap_scale_shift(weight):
+     scale, shift = weight.chunk(2, dim=0)
+     new_weight = torch.cat([shift, scale], dim=0)
+     return new_weight
+
+
+ def check_if_lora_exists(state_dict, lora_name):
+     subkey = lora_name.split(".lora_A")[0].split(".lora_B")[0].split(".weight")[0]
+     for key in state_dict.keys():
+         if subkey in key:
+             return subkey
+     return False
+
+
+ def convert_if_lora_exists(new_state_dict, state_dict, lora_name, flux_layer_name):
+     if (original_stubkey := check_if_lora_exists(state_dict, lora_name)) != False:
+         weights_to_pop = [k for k in state_dict.keys() if original_stubkey in k]
+         for key in weights_to_pop:
+             key_replacement = key.replace(
+                 original_stubkey, flux_layer_name.replace(".weight", "")
+             )
+             new_state_dict[key_replacement] = state_dict.pop(key)
+         return new_state_dict, state_dict
+     else:
+         return new_state_dict, state_dict
+
+
+ def convert_diffusers_to_flux_transformer_checkpoint(
+     diffusers_state_dict,
+     num_layers,
+     num_single_layers,
+     has_guidance=True,
+     prefix="",
+ ):
+     original_state_dict = {}
+
+     # time_text_embed.timestep_embedder -> time_in
+     original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+         original_state_dict,
+         diffusers_state_dict,
+         f"{prefix}time_text_embed.timestep_embedder.linear_1.weight",
+         "time_in.in_layer.weight",
+     )
+     # time_text_embed.text_embedder -> vector_in
+     original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+         original_state_dict,
+         diffusers_state_dict,
+         f"{prefix}time_text_embed.text_embedder.linear_1.weight",
+         "vector_in.in_layer.weight",
+     )
+
+     original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+         original_state_dict,
+         diffusers_state_dict,
+         f"{prefix}time_text_embed.text_embedder.linear_2.weight",
+         "vector_in.out_layer.weight",
+     )
+
+     if has_guidance:
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}time_text_embed.guidance_embedder.linear_1.weight",
+             "guidance_in.in_layer.weight",
+         )
+
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}time_text_embed.guidance_embedder.linear_2.weight",
+             "guidance_in.out_layer.weight",
+         )
+
+     # context_embedder -> txt_in
+     original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+         original_state_dict,
+         diffusers_state_dict,
+         f"{prefix}context_embedder.weight",
+         "txt_in.weight",
+     )
+
+     # x_embedder -> img_in
+     original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+         original_state_dict,
+         diffusers_state_dict,
+         f"{prefix}x_embedder.weight",
+         "img_in.weight",
+     )
+     # double transformer blocks
+     for i in range(num_layers):
+         block_prefix = f"transformer_blocks.{i}."
+         # norms
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}norm1.linear.weight",
+             f"double_blocks.{i}.img_mod.lin.weight",
+         )
+
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}norm1_context.linear.weight",
+             f"double_blocks.{i}.txt_mod.lin.weight",
+         )
+
+         sample_q_A = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.to_q.lora_A.weight"
+         )
+         sample_q_B = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.to_q.lora_B.weight"
+         )
+
+         sample_k_A = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.to_k.lora_A.weight"
+         )
+         sample_k_B = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.to_k.lora_B.weight"
+         )
+
+         sample_v_A = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.to_v.lora_A.weight"
+         )
+         sample_v_B = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.to_v.lora_B.weight"
+         )
+
+         context_q_A = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.add_q_proj.lora_A.weight"
+         )
+         context_q_B = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.add_q_proj.lora_B.weight"
+         )
+
+         context_k_A = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.add_k_proj.lora_A.weight"
+         )
+         context_k_B = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.add_k_proj.lora_B.weight"
+         )
+         context_v_A = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.add_v_proj.lora_A.weight"
+         )
+         context_v_B = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}attn.add_v_proj.lora_B.weight"
+         )
+
+         original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_A.weight"] = (
+             torch.cat([sample_q_A, sample_k_A, sample_v_A], dim=0)
+         )
+         original_state_dict[f"double_blocks.{i}.img_attn.qkv.lora_B.weight"] = (
+             torch.cat([sample_q_B, sample_k_B, sample_v_B], dim=0)
+         )
+         original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_A.weight"] = (
+             torch.cat([context_q_A, context_k_A, context_v_A], dim=0)
+         )
+         original_state_dict[f"double_blocks.{i}.txt_attn.qkv.lora_B.weight"] = (
+             torch.cat([context_q_B, context_k_B, context_v_B], dim=0)
+         )
+
+         # qk_norm
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}attn.norm_q.weight",
+             f"double_blocks.{i}.img_attn.norm.query_norm.scale",
+         )
+
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}attn.norm_k.weight",
+             f"double_blocks.{i}.img_attn.norm.key_norm.scale",
+         )
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}attn.norm_added_q.weight",
+             f"double_blocks.{i}.txt_attn.norm.query_norm.scale",
+         )
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}attn.norm_added_k.weight",
+             f"double_blocks.{i}.txt_attn.norm.key_norm.scale",
+         )
+
+         # ff img_mlp
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}ff.net.0.proj.weight",
+             f"double_blocks.{i}.img_mlp.0.weight",
+         )
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}ff.net.2.weight",
+             f"double_blocks.{i}.img_mlp.2.weight",
+         )
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}ff_context.net.0.proj.weight",
+             f"double_blocks.{i}.txt_mlp.0.weight",
+         )
+
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}ff_context.net.2.weight",
+             f"double_blocks.{i}.txt_mlp.2.weight",
+         )
+         # output projections
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}attn.to_out.0.weight",
+             f"double_blocks.{i}.img_attn.proj.weight",
+         )
+
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}attn.to_add_out.weight",
+             f"double_blocks.{i}.txt_attn.proj.weight",
+         )
+
+     # single transformer blocks
+     for i in range(num_single_layers):
+         block_prefix = f"single_transformer_blocks.{i}."
+         # norm.linear -> single_blocks.0.modulation.lin
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}norm.linear.weight",
+             f"single_blocks.{i}.modulation.lin.weight",
+         )
+
+         # Q, K, V, mlp
+         q_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight")
+         q_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight")
+         k_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight")
+         k_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight")
+         v_A = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight")
+         v_B = diffusers_state_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight")
+         mlp_A = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}proj_mlp.lora_A.weight"
+         )
+         mlp_B = diffusers_state_dict.pop(
+             f"{prefix}{block_prefix}proj_mlp.lora_B.weight"
+         )
+         original_state_dict[f"single_blocks.{i}.linear1.lora_A.weight"] = torch.cat(
+             [q_A, k_A, v_A, mlp_A], dim=0
+         )
+         original_state_dict[f"single_blocks.{i}.linear1.lora_B.weight"] = torch.cat(
+             [q_B, k_B, v_B, mlp_B], dim=0
+         )
+
+         # output projections
+         original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+             original_state_dict,
+             diffusers_state_dict,
+             f"{prefix}{block_prefix}proj_out.weight",
+             f"single_blocks.{i}.linear2.weight",
+         )
+
+     original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+         original_state_dict,
+         diffusers_state_dict,
+         f"{prefix}proj_out.weight",
+         "final_layer.linear.weight",
+     )
+     original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+         original_state_dict,
+         diffusers_state_dict,
+         f"{prefix}proj_out.bias",
+         "final_layer.linear.bias",
+     )
+     original_state_dict, diffusers_state_dict = convert_if_lora_exists(
+         original_state_dict,
+         diffusers_state_dict,
+         f"{prefix}norm_out.linear.weight",
+         "final_layer.adaLN_modulation.1.weight",
+     )
+     if len(list(diffusers_state_dict.keys())) > 0:
+         logger.warning(f"Unexpected keys: {list(diffusers_state_dict.keys())}")
+
+     return original_state_dict
+
+
+ def convert_from_original_flux_checkpoint(
+     original_state_dict,
+ ):
+     sd = {
+         k.replace("lora_unet_", "")
+         .replace("double_blocks_", "double_blocks.")
+         .replace("single_blocks_", "single_blocks.")
+         .replace("_img_attn_", ".img_attn.")
+         .replace("_txt_attn_", ".txt_attn.")
+         .replace("_img_mod_", ".img_mod.")
+         .replace("_txt_mod_", ".txt_mod.")
+         .replace("_img_mlp_", ".img_mlp.")
+         .replace("_txt_mlp_", ".txt_mlp.")
+         .replace("_linear1", ".linear1")
+         .replace("_linear2", ".linear2")
+         .replace("_modulation_", ".modulation.")
+         .replace("lora_up", "lora_B")
+         .replace("lora_down", "lora_A"): v
+         for k, v in original_state_dict.items()
+         if "lora" in k
+     }
+     return sd
+
+
+ def get_module_for_key(
+     key: str, model: Flux
+ ) -> F8Linear | torch.nn.Linear | CublasLinear:
+     parts = key.split(".")
+     module = model
+     for part in parts:
+         module = getattr(module, part)
+     return module
+
+
+ def get_lora_for_key(key: str, lora_weights: dict):
+     prefix = key.split(".lora")[0]
+     lora_A = lora_weights[f"{prefix}.lora_A.weight"]
+     lora_B = lora_weights[f"{prefix}.lora_B.weight"]
+     alpha = lora_weights.get(f"{prefix}.alpha", 1.0)
+     return lora_A, lora_B, alpha
+
+
+ @torch.inference_mode()
+ def apply_lora_weight_to_module(
+     module_weight: torch.Tensor,
+     lora_weights: dict,
+     rank: int = None,
+     lora_scale: float = 1.0,
+     from_original_flux: bool = False,
+ ):
+     lora_A, lora_B, alpha = lora_weights
+
+     uneven_rank = lora_B.shape[1] != lora_A.shape[0]
+     rank_diff = lora_A.shape[0] / lora_B.shape[1]
+
+     if rank is None:
+         rank = lora_B.shape[1]
+     if alpha is None:
+         alpha = rank
+     w_dtype = module_weight.dtype
+     dtype = torch.float32
+     device = module_weight.device
+     w_orig = module_weight.to(dtype=dtype, device=device)
+     w_up = lora_A.to(dtype=dtype, device=device)
+     w_down = lora_B.to(dtype=dtype, device=device)
+
+     if alpha != rank:
+         w_up = w_up * alpha / rank
+     if uneven_rank:
+         fused_lora = lora_scale * torch.mm(
+             w_down.repeat_interleave(int(rank_diff), dim=1), w_up
+         )
+     else:
+         fused_lora = lora_scale * torch.mm(w_down, w_up)
+     fused_weight = w_orig + fused_lora
+     return fused_weight.to(dtype=w_dtype, device=device)
+
+
+ @torch.inference_mode()
+ def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0):
+     has_guidance = model.params.guidance_embed
+     logger.info(f"Loading LoRA weights for {lora_path}")
+     lora_weights = load_file(lora_path)
+     from_original_flux = False
+     check_if_starts_with_transformer = [
+         k for k in lora_weights.keys() if k.startswith("transformer.")
+     ]
+     if len(check_if_starts_with_transformer) > 0:
+         lora_weights = convert_diffusers_to_flux_transformer_checkpoint(
+             lora_weights, 19, 38, has_guidance=has_guidance, prefix="transformer."
+         )
+     else:
+         from_original_flux = True
+         lora_weights = convert_from_original_flux_checkpoint(lora_weights)
+     logger.info("LoRA weights loaded")
+     logger.debug("Extracting keys")
+     keys_without_ab = [
+         key.replace(".lora_A.weight", "")
+         .replace(".lora_B.weight", "")
+         .replace(".alpha", "")
+         for key in lora_weights.keys()
+     ]
+     logger.debug("Keys extracted")
+     keys_without_ab = list(set(keys_without_ab))
+     if len(keys_without_ab) > 0:
+         logger.warning(f"Missing unconverted state dict keys! {len(keys_without_ab)}")
+
+     for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
+         module = get_module_for_key(key, model)
+         dtype = model.dtype
+         weight_is_f8 = False
+         if isinstance(module, F8Linear):
+             weight_is_f8 = True
+             weight_f16 = (
+                 module.float8_data.clone()
+                 .detach()
+                 .float()
+                 .mul(module.scale_reciprocal)
+                 .to(module.weight.device)
+             )
+         elif isinstance(module, torch.nn.Linear):
+             weight_f16 = module.weight.clone().detach().float()
+         elif isinstance(module, CublasLinear):
+             weight_f16 = module.weight.clone().detach().float()
+         lora_sd = get_lora_for_key(key, lora_weights)
+         weight_f16 = apply_lora_weight_to_module(
+             weight_f16,
+             lora_sd,
+             lora_scale=lora_scale,
+             from_original_flux=from_original_flux,
+         )
+         if weight_is_f8:
+             module.set_weight_tensor(weight_f16.type(dtype))
+         else:
+             module.weight.data = weight_f16.type(dtype)
+     logger.success("LoRA applied")
+     return model
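For reference, `FluxPipeline.load_lora` is a thin wrapper around `apply_lora_to_model`. A minimal sketch of the equivalent low-level call, illustrative only and reusing the placeholder paths from the README example:

```python
import lora_loading
from flux_pipeline import FluxPipeline

pipeline = FluxPipeline.load_pipeline_from_config_path("path/to/config/file.json")

# Equivalent to pipeline.load_lora(lora_path, scale=1.0): each targeted weight is
# replaced in place with W + lora_scale * (B @ A), with the product rescaled by
# alpha/rank when they differ (see apply_lora_weight_to_module above).
pipeline.model = lora_loading.apply_lora_to_model(
    pipeline.model, "path/to/lora/file.safetensors", lora_scale=1.0
)
```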