Safetensors
aredden committed on
Commit d33c2e4 · 1 Parent(s): d316f04

lora unloading initial

Files changed (4):
  1. lora_loading.py +151 -43
  2. modules/conditioner.py +2 -1
  3. modules/flux_model.py +61 -5
  4. util.py +5 -1
lora_loading.py CHANGED
@@ -1,7 +1,10 @@
+import re
+from typing import Optional, OrderedDict, Tuple, TypeAlias, Union
 import torch
 from loguru import logger
 from safetensors.torch import load_file
 from tqdm import tqdm
+from torch import nn
 
 try:
     from cublas_ops import CublasLinear
@@ -10,6 +13,24 @@ except Exception as e:
 from float8_quantize import F8Linear
 from modules.flux_model import Flux
 
+path_regex = re.compile(r"\/|\\")
+
+StateDict: TypeAlias = OrderedDict[str, torch.Tensor]
+
+
+class LoraWeights:
+    def __init__(
+        self,
+        weights: StateDict,
+        path: str,
+        name: str = None,
+        scale: float = 1.0,
+    ) -> None:
+        self.path = path
+        self.weights = weights
+        self.name = name if name else path_regex.split(path)[-1]
+        self.scale = scale
+
 
 def swap_scale_shift(weight):
     scale, shift = weight.chunk(2, dim=0)
@@ -345,52 +366,74 @@ def get_lora_for_key(key: str, lora_weights: dict):
     return lora_A, lora_B, alpha
 
 
-@torch.inference_mode()
-def apply_lora_weight_to_module(
-    module_weight: torch.Tensor,
-    lora_weights: dict,
-    rank: int = None,
+def calculate_lora_weight(
+    lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float]],
+    rank: Optional[int] = None,
     lora_scale: float = 1.0,
+    device: Optional[Union[torch.device, int, str]] = None,
 ):
     lora_A, lora_B, alpha = lora_weights
+    if device is None:
+        device = lora_A.device
 
     uneven_rank = lora_B.shape[1] != lora_A.shape[0]
     rank_diff = lora_A.shape[0] / lora_B.shape[1]
 
     if rank is None:
         rank = lora_B.shape[1]
-    else:
-        rank = rank
     if alpha is None:
         alpha = rank
-    else:
-        alpha = alpha
-    w_dtype = module_weight.dtype
+
     dtype = torch.float32
-    device = module_weight.device
-    w_orig = module_weight.to(dtype=dtype, device=device)
     w_up = lora_A.to(dtype=dtype, device=device)
     w_down = lora_B.to(dtype=dtype, device=device)
 
-    # if not from_original_flux:
     if alpha != rank:
-        w_up = w_up * alpha / rank
+        w_up = w_up * (alpha / rank)
+
     if uneven_rank:
         fused_lora = lora_scale * torch.mm(
             w_down.repeat_interleave(int(rank_diff), dim=1), w_up
         )
     else:
         fused_lora = lora_scale * torch.mm(w_down, w_up)
-    fused_weight = w_orig + fused_lora
-    return fused_weight.to(dtype=w_dtype, device=device)
+    return fused_lora
 
 
 @torch.inference_mode()
-def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0) -> Flux:
-    has_guidance = model.params.guidance_embed
-    logger.info(f"Loading LoRA weights for {lora_path}")
-    lora_weights = load_file(lora_path)
-    from_original_flux = False
+def unfuse_lora_weight_from_module(
+    fused_weight: torch.Tensor,
+    lora_weights: dict,
+    rank: Optional[int] = None,
+    lora_scale: float = 1.0,
+):
+    w_dtype = fused_weight.dtype
+    dtype = torch.float32
+    device = fused_weight.device
+
+    fused_weight = fused_weight.to(dtype=dtype, device=device)
+    fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
+    module_weight = fused_weight - fused_lora
+    return module_weight.to(dtype=w_dtype, device=device)
+
+
+@torch.inference_mode()
+def apply_lora_weight_to_module(
+    module_weight: torch.Tensor,
+    lora_weights: dict,
+    rank: int = None,
+    lora_scale: float = 1.0,
+):
+    w_dtype = module_weight.dtype
+    dtype = torch.float32
+    device = module_weight.device
+
+    fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
+    fused_weight = module_weight.to(dtype=dtype) + fused_lora
+    return fused_weight.to(dtype=w_dtype, device=device)
+
+
+def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
     check_if_starts_with_transformer = [
         k for k in lora_weights.keys() if k.startswith("transformer.")
     ]
@@ -399,43 +442,108 @@ def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0) ->
             lora_weights, 19, 38, has_guidance=has_guidance, prefix="transformer."
         )
     else:
-        from_original_flux = True
        lora_weights = convert_from_original_flux_checkpoint(lora_weights)
     logger.info("LoRA weights loaded")
     logger.debug("Extracting keys")
     keys_without_ab = [
         key.replace(".lora_A.weight", "")
         .replace(".lora_B.weight", "")
+        .replace(".lora_A", "")
+        .replace(".lora_B", "")
         .replace(".alpha", "")
         for key in lora_weights.keys()
     ]
     logger.debug("Keys extracted")
     keys_without_ab = list(set(keys_without_ab))
+    keys_without_ab = list(
+        set(
+            [
+                key.replace(".lora_A.weight", "")
+                .replace(".lora_B.weight", "")
+                .replace(".lora_A", "")
+                .replace(".lora_B", "")
+                .replace(".alpha", "")
+                for key in keys_without_ab
+            ]
+        )
+    )
+    return keys_without_ab, lora_weights
+
+
+def get_lora_weights(lora_path: str | StateDict):
+    if isinstance(lora_path, dict):
+        return lora_path, True
+    else:
+        return load_file(lora_path, "cpu"), False
+
+
+def extract_weight_from_linear(linear: Union[nn.Linear, CublasLinear, F8Linear]):
+    dtype = linear.weight.dtype
+    weight_is_f8 = False
+    if isinstance(linear, F8Linear):
+        weight_is_f8 = True
+        weight = (
+            linear.float8_data.clone()
+            .detach()
+            .float()
+            .mul(linear.scale_reciprocal)
+            .to(linear.weight.device)
+        )
+    elif isinstance(linear, torch.nn.Linear):
+        weight = linear.weight.clone().detach().float()
+    elif isinstance(linear, CublasLinear):
+        weight = linear.weight.clone().detach().float()
+    return weight, weight_is_f8, dtype
+
+
+@torch.inference_mode()
+def apply_lora_to_model(
+    model: Flux,
+    lora_path: str | StateDict,
+    lora_scale: float = 1.0,
+    return_lora_resolved: bool = False,
+) -> Flux:
+    has_guidance = model.params.guidance_embed
+    logger.info(f"Loading LoRA weights for {lora_path}")
+    lora_weights, _ = get_lora_weights(lora_path)
+
+    keys_without_ab, lora_weights = resolve_lora_state_dict(lora_weights, has_guidance)
 
     for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
         module = get_module_for_key(key, model)
-        dtype = model.dtype
-        weight_is_f8 = False
-        if isinstance(module, F8Linear):
-            weight_is_f8 = True
-            weight_f16 = (
-                module.float8_data.clone()
-                .detach()
-                .float()
-                .mul(module.scale_reciprocal)
-                .to(module.weight.device)
-            )
-        elif isinstance(module, torch.nn.Linear):
-            weight_f16 = module.weight.clone().detach().float()
-        elif isinstance(module, CublasLinear):
-            weight_f16 = module.weight.clone().detach().float()
+        weight, is_f8, dtype = extract_weight_from_linear(module)
         lora_sd = get_lora_for_key(key, lora_weights)
-        weight_f16 = apply_lora_weight_to_module(
-            weight_f16, lora_sd, lora_scale=lora_scale
-        )
-        if weight_is_f8:
-            module.set_weight_tensor(weight_f16.type(dtype))
+        weight = apply_lora_weight_to_module(weight, lora_sd, lora_scale=lora_scale)
+        if is_f8:
+            module.set_weight_tensor(weight.type(dtype))
         else:
-            module.weight.data = weight_f16.type(dtype)
+            module.weight.data = weight.type(dtype)
     logger.success("Lora applied")
+    if return_lora_resolved:
+        return model, lora_weights
+    return model
+
+
+def remove_lora_from_module(
+    model: Flux,
+    lora_path: str | StateDict,
+    lora_scale: float = 1.0,
+):
+    has_guidance = model.params.guidance_embed
+    logger.info(f"Loading LoRA weights for {lora_path}")
+    lora_weights = get_lora_weights(lora_path)
+    lora_weights, _ = get_lora_weights(lora_path)
+
+    keys_without_ab, lora_weights = resolve_lora_state_dict(lora_weights, has_guidance)
+
+    for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
+        module = get_module_for_key(key, model)
+        weight, is_f8, dtype = extract_weight_from_linear(module)
+        lora_sd = get_lora_for_key(key, lora_weights)
+        weight = unfuse_lora_weight_from_module(weight, lora_sd, lora_scale=lora_scale)
+        if is_f8:
+            module.set_weight_tensor(weight.type(dtype))
+        else:
+            module.weight.data = weight.type(dtype)
    logger.success("Lora unfused")
     return model
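
Taken together, the new helpers make LoRA fusion reversible: apply_lora_to_model adds the scaled low-rank product (alpha / rank) * (B @ A) to each matched linear weight in float32, and remove_lora_from_module subtracts the identical product to restore the base weights. A minimal usage sketch, assuming placeholder paths for the Flux checkpoint and the LoRA safetensors file:

import torch

from lora_loading import apply_lora_to_model, remove_lora_from_module
from modules.flux_model import Flux

# Placeholder paths; substitute a real Flux checkpoint and LoRA safetensors file.
model = Flux.from_pretrained("path/to/flux.safetensors", dtype=torch.float16)

# Fuse: W <- W + lora_scale * (alpha / rank) * (B @ A) for every targeted linear layer.
model = apply_lora_to_model(model, "path/to/lora.safetensors", lora_scale=0.8)

# ... run inference with the fused weights ...

# Unfuse with the same scale to recover the original weights (up to cast rounding,
# since the delta is computed in float32 and cast back to the stored dtype).
model = remove_lora_from_module(model, "path/to/lora.safetensors", lora_scale=0.8)

Because both paths compute the same delta through calculate_lora_weight, unloading only has to subtract what fusion added, which also works for F8Linear modules via extract_weight_from_linear and set_weight_tensor.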
modules/conditioner.py CHANGED
@@ -43,6 +43,7 @@ class HFEmbedder(nn.Module):
         device: torch.device | int,
         quantization_dtype: str | None = None,
         offloading_device: torch.device | int | None = torch.device("cpu"),
+        is_clip: bool = False,
         **hf_kwargs,
     ):
         super().__init__()
@@ -54,7 +55,7 @@ class HFEmbedder(nn.Module):
         self.device = (
             device if isinstance(device, torch.device) else torch.device(device)
         )
-        self.is_clip = version.startswith("openai")
+        self.is_clip = version.startswith("openai") or is_clip
         self.max_length = max_length
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
 
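The new is_clip flag means a CLIP text encoder no longer has to live under the openai/ namespace to be routed to pooler_output. A small sketch, assuming a hypothetical fine-tuned CLIP-L repo id:

import torch

from modules.conditioner import HFEmbedder

# Hypothetical fine-tuned CLIP-L repo id; without is_clip=True the version-prefix
# check alone would fall back to last_hidden_state instead of pooler_output.
clip = HFEmbedder(
    "someuser/finetuned-clip-vit-large-patch14",
    max_length=77,
    torch_dtype=torch.float16,
    device=0,
    is_clip=True,
)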
modules/flux_model.py CHANGED
@@ -1,11 +1,13 @@
-from collections import namedtuple
 import os
-from typing import TYPE_CHECKING
+from collections import namedtuple
+from typing import TYPE_CHECKING, List
+
 import torch
+from loguru import logger
 
 if TYPE_CHECKING:
+    from lora_loading import LoraWeights
     from util import ModelSpec
-
 DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
@@ -14,8 +16,8 @@ torch.backends.cudnn.benchmark_limit = 20
 torch.set_float32_matmul_precision("high")
 import math
 
-from torch import Tensor, nn
 from pydantic import BaseModel
+from torch import Tensor, nn
 from torch.nn import functional as F
 
 
@@ -345,6 +347,7 @@ class DoubleStreamBlock(nn.Module):
         self.H = self.num_heads
         self.KH = self.K * self.H
         self.do_clamp = dtype == torch.float16
+
     def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
         B, L, D = x.shape
         q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
@@ -512,6 +515,7 @@ class Flux(nn.Module):
         self.params = config.params
         self.in_channels = config.params.in_channels
         self.out_channels = self.in_channels
+        self.loras: List[LoraWeights] = []
         prequantized_flow = config.prequantized_flow
         quantized_embedders = config.quantize_flow_embedder_layers and prequantized_flow
         quantized_modulation = config.quantize_modulation and prequantized_flow
@@ -614,6 +618,57 @@ class Flux(nn.Module):
 
         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 
+    def get_lora(self, identifier: str):
+        for lora in self.loras:
+            if lora.path == identifier or lora.name == identifier:
+                return lora
+
+    def has_lora(self, identifier: str):
+        for lora in self.loras:
+            if lora.path == identifier or lora.name == identifier:
+                return True
+
+    def load_lora(self, path: str, scale: float, name: str = None):
+        from lora_loading import (
+            LoraWeights,
+            apply_lora_to_model,
+            remove_lora_from_module,
+        )
+
+        if self.has_lora(path):
+            lora = self.get_lora(path)
+            if lora.scale == scale:
+                logger.warning(
+                    f"Lora {lora.name} already loaded with same scale - ignoring!"
+                )
+            else:
+                remove_lora_from_module(self, lora, lora.scale)
+                apply_lora_to_model(self, lora, scale)
+                for idx, lora_ in enumerate(self.loras):
+                    if lora_.path == lora.path:
+                        self.loras[idx].scale = scale
+                        break
+        else:
+            _, lora = apply_lora_to_model(self, path, scale, return_lora_resolved=True)
+            self.loras.append(LoraWeights(lora, path, name, scale))
+
+    def unload_lora(self, path_or_identifier: str):
+        from lora_loading import remove_lora_from_module
+
+        removed = False
+        for idx, lora_ in enumerate(list(self.loras)):
+            if lora_.path == path_or_identifier or lora_.name == path_or_identifier:
+                remove_lora_from_module(self, lora_.weights, lora_.scale)
+                self.loras.pop(idx)
+                removed = True
+                break
+        if not removed:
+            logger.warning(
+                f"Couldn't remove lora {path_or_identifier} as it wasn't found fused to the model!"
+            )
+        else:
+            logger.info("Successfully removed lora from module.")
+
     def forward(
         self,
         img: Tensor,
@@ -664,9 +719,10 @@ class Flux(nn.Module):
     def from_pretrained(
         cls: "Flux", path: str, dtype: torch.dtype = torch.float16
     ) -> "Flux":
-        from util import load_config_from_path
         from safetensors.torch import load_file
 
+        from util import load_config_from_path
+
         config = load_config_from_path(path)
         with torch.device("meta"):
             klass = cls(config=config, dtype=dtype)
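
The new Flux.load_lora / Flux.unload_lora methods track fused adapters in self.loras, so a LoRA can be re-fused at a different scale or removed later by path or name. A short sketch, assuming model is an already-constructed Flux instance and the LoRA path is a placeholder:

# model: Flux, already constructed and loaded elsewhere.
model.load_lora("loras/detail-boost.safetensors", scale=0.7)

# Calling load_lora again with a different scale unfuses the old delta and re-fuses
# at the new scale; calling with the same scale logs a warning and is a no-op.
model.load_lora("loras/detail-boost.safetensors", scale=1.0)

# Unload by path or by name; the name defaults to the file's basename.
model.unload_lora("detail-boost.safetensors")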
util.py CHANGED
@@ -34,11 +34,14 @@ class QuantizationDtype(StrEnum):
     bfloat16 = "bfloat16"
     float16 = "float16"
 
+
 class ModelSpec(BaseModel):
     version: ModelVersion
     params: FluxParams
     ae_params: AutoEncoderParams
     ckpt_path: str | None
+    # Add option to pass in custom clip model
+    clip_path: str | None = "openai/clip-vit-large-patch14"
     ae_path: str | None
     repo_id: str | None
     repo_flow: str | None
@@ -255,10 +258,11 @@ def load_flow_model(config: ModelSpec) -> Flux:
 
 def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
     clip = HFEmbedder(
-        "openai/clip-vit-large-patch14",
+        config.clip_path,
         max_length=77,
         torch_dtype=into_dtype(config.text_enc_dtype),
        device=into_device(config.text_enc_device).index or 0,
+        is_clip=True,
         quantization_dtype=config.clip_quantization_dtype,
     )
     t5 = HFEmbedder(
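
Because clip_path defaults to openai/clip-vit-large-patch14, existing configs keep their current behavior; only configs that want a custom CLIP encoder need the new key. A brief sketch, with a placeholder config path and a hypothetical repo id:

from util import load_config_from_path, load_text_encoders

config = load_config_from_path("configs/config-dev-fp8.json")  # placeholder config path
print(config.clip_path)  # "openai/clip-vit-large-patch14" unless overridden in the config

# Setting the field (or adding "clip_path" to the JSON config) routes the custom encoder
# through load_text_encoders, which now also passes is_clip=True to HFEmbedder.
config.clip_path = "someuser/finetuned-clip-vit-large-patch14"  # hypothetical repo id
clip, t5 = load_text_encoders(config)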