lora unloading initial
- lora_loading.py +151 -43
- modules/conditioner.py +2 -1
- modules/flux_model.py +61 -5
- util.py +5 -1
lora_loading.py
CHANGED
@@ -1,7 +1,10 @@
+import re
+from typing import Optional, OrderedDict, Tuple, TypeAlias, Union
 import torch
 from loguru import logger
 from safetensors.torch import load_file
 from tqdm import tqdm
+from torch import nn

 try:
     from cublas_ops import CublasLinear
@@ -10,6 +13,24 @@ except Exception as e:
 from float8_quantize import F8Linear
 from modules.flux_model import Flux

+path_regex = re.compile(r"\/|\\")
+
+StateDict: TypeAlias = OrderedDict[str, torch.Tensor]
+
+
+class LoraWeights:
+    def __init__(
+        self,
+        weights: StateDict,
+        path: str,
+        name: str = None,
+        scale: float = 1.0,
+    ) -> None:
+        self.path = path
+        self.weights = weights
+        self.name = name if name else path_regex.split(path)[-1]
+        self.scale = scale
+

 def swap_scale_shift(weight):
     scale, shift = weight.chunk(2, dim=0)
@@ -345,52 +366,74 @@ def get_lora_for_key(key: str, lora_weights: dict):
     return lora_A, lora_B, alpha


-@torch.inference_mode()
-def apply_lora_weight_to_module(
-    module_weight: torch.Tensor,
-    lora_weights: dict,
-    rank: int = None,
+def calculate_lora_weight(
+    lora_weights: Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, float]],
+    rank: Optional[int] = None,
     lora_scale: float = 1.0,
+    device: Optional[Union[torch.device, int, str]] = None,
 ):
     lora_A, lora_B, alpha = lora_weights
+    if device is None:
+        device = lora_A.device

     uneven_rank = lora_B.shape[1] != lora_A.shape[0]
     rank_diff = lora_A.shape[0] / lora_B.shape[1]

     if rank is None:
         rank = lora_B.shape[1]
-    else:
-        rank = rank
     if alpha is None:
         alpha = rank
-    else:
-        alpha = alpha
-    w_dtype = module_weight.dtype
+
     dtype = torch.float32
-    device = module_weight.device
-    w_orig = module_weight.to(dtype=dtype, device=device)
     w_up = lora_A.to(dtype=dtype, device=device)
     w_down = lora_B.to(dtype=dtype, device=device)

-    # if not from_original_flux:
     if alpha != rank:
-        w_up = w_up * alpha / rank
+        w_up = w_up * (alpha / rank)
+
     if uneven_rank:
         fused_lora = lora_scale * torch.mm(
             w_down.repeat_interleave(int(rank_diff), dim=1), w_up
         )
     else:
         fused_lora = lora_scale * torch.mm(w_down, w_up)
-    fused_weight = w_orig + fused_lora
-    return fused_weight.to(dtype=w_dtype, device=device)
+    return fused_lora


 @torch.inference_mode()
-def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0) -> Flux:
-    has_guidance = model.params.guidance_embed
-    logger.info(f"Loading LoRA weights for {lora_path}")
-    lora_weights = load_file(lora_path, "cpu")
-
+def unfuse_lora_weight_from_module(
+    fused_weight: torch.Tensor,
+    lora_weights: dict,
+    rank: Optional[int] = None,
+    lora_scale: float = 1.0,
+):
+    w_dtype = fused_weight.dtype
+    dtype = torch.float32
+    device = fused_weight.device
+
+    fused_weight = fused_weight.to(dtype=dtype, device=device)
+    fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
+    module_weight = fused_weight - fused_lora
+    return module_weight.to(dtype=w_dtype, device=device)
+
+
+@torch.inference_mode()
+def apply_lora_weight_to_module(
+    module_weight: torch.Tensor,
+    lora_weights: dict,
+    rank: int = None,
+    lora_scale: float = 1.0,
+):
+    w_dtype = module_weight.dtype
+    dtype = torch.float32
+    device = module_weight.device
+
+    fused_lora = calculate_lora_weight(lora_weights, rank, lora_scale, device=device)
+    fused_weight = module_weight.to(dtype=dtype) + fused_lora
+    return fused_weight.to(dtype=w_dtype, device=device)
+
+
+def resolve_lora_state_dict(lora_weights, has_guidance: bool = True):
     check_if_starts_with_transformer = [
         k for k in lora_weights.keys() if k.startswith("transformer.")
     ]
@@ -399,43 +442,108 @@ def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0) ->
             lora_weights, 19, 38, has_guidance=has_guidance, prefix="transformer."
         )
     else:
-        from_original_flux = True
         lora_weights = convert_from_original_flux_checkpoint(lora_weights)
     logger.info("LoRA weights loaded")
     logger.debug("Extracting keys")
     keys_without_ab = [
         key.replace(".lora_A.weight", "")
         .replace(".lora_B.weight", "")
+        .replace(".lora_A", "")
+        .replace(".lora_B", "")
         .replace(".alpha", "")
         for key in lora_weights.keys()
     ]
     logger.debug("Keys extracted")
     keys_without_ab = list(set(keys_without_ab))
+    keys_without_ab = list(
+        set(
+            [
+                key.replace(".lora_A.weight", "")
+                .replace(".lora_B.weight", "")
+                .replace(".lora_A", "")
+                .replace(".lora_B", "")
+                .replace(".alpha", "")
+                for key in keys_without_ab
+            ]
+        )
+    )
+    return keys_without_ab, lora_weights
+
+
+def get_lora_weights(lora_path: str | StateDict):
+    if isinstance(lora_path, dict):
+        return lora_path, True
+    else:
+        return load_file(lora_path, "cpu"), False
+
+
+def extract_weight_from_linear(linear: Union[nn.Linear, CublasLinear, F8Linear]):
+    dtype = linear.weight.dtype
+    weight_is_f8 = False
+    if isinstance(linear, F8Linear):
+        weight_is_f8 = True
+        weight = (
+            linear.float8_data.clone()
+            .detach()
+            .float()
+            .mul(linear.scale_reciprocal)
+            .to(linear.weight.device)
+        )
+    elif isinstance(linear, torch.nn.Linear):
+        weight = linear.weight.clone().detach().float()
+    elif isinstance(linear, CublasLinear):
+        weight = linear.weight.clone().detach().float()
+    return weight, weight_is_f8, dtype
+
+
+@torch.inference_mode()
+def apply_lora_to_model(
+    model: Flux,
+    lora_path: str | StateDict,
+    lora_scale: float = 1.0,
+    return_lora_resolved: bool = False,
+) -> Flux:
+    has_guidance = model.params.guidance_embed
+    logger.info(f"Loading LoRA weights for {lora_path}")
+    lora_weights, _ = get_lora_weights(lora_path)
+
+    keys_without_ab, lora_weights = resolve_lora_state_dict(lora_weights, has_guidance)

     for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
         module = get_module_for_key(key, model)
-        dtype = module.weight.dtype
-        weight_is_f8 = False
-        if isinstance(module, F8Linear):
-            weight_is_f8 = True
-            weight_f16 = (
-                module.float8_data.clone()
-                .detach()
-                .float()
-                .mul(module.scale_reciprocal)
-                .to(module.weight.device)
-            )
-        elif isinstance(module, torch.nn.Linear):
-            weight_f16 = module.weight.clone().detach().float()
-        elif isinstance(module, CublasLinear):
-            weight_f16 = module.weight.clone().detach().float()
+        weight, is_f8, dtype = extract_weight_from_linear(module)
         lora_sd = get_lora_for_key(key, lora_weights)
-        weight_f16 = apply_lora_weight_to_module(
-            weight_f16, lora_sd, lora_scale=lora_scale
-        )
-        if weight_is_f8:
-            module.set_weight_tensor(weight_f16.type(dtype))
+        weight = apply_lora_weight_to_module(weight, lora_sd, lora_scale=lora_scale)
+        if is_f8:
+            module.set_weight_tensor(weight.type(dtype))
         else:
-            module.weight.data = weight_f16.type(dtype)
+            module.weight.data = weight.type(dtype)
     logger.success("Lora applied")
+    if return_lora_resolved:
+        return model, lora_weights
+    return model
+
+
+def remove_lora_from_module(
+    model: Flux,
+    lora_path: str | StateDict,
+    lora_scale: float = 1.0,
+):
+    has_guidance = model.params.guidance_embed
+    logger.info(f"Loading LoRA weights for {lora_path}")
+    lora_weights = get_lora_weights(lora_path)
+    lora_weights, _ = get_lora_weights(lora_path)
+
+    keys_without_ab, lora_weights = resolve_lora_state_dict(lora_weights, has_guidance)
+
+    for key in tqdm(keys_without_ab, desc="Unfusing LoRA", total=len(keys_without_ab)):
+        module = get_module_for_key(key, model)
+        weight, is_f8, dtype = extract_weight_from_linear(module)
+        lora_sd = get_lora_for_key(key, lora_weights)
+        weight = unfuse_lora_weight_from_module(weight, lora_sd, lora_scale=lora_scale)
+        if is_f8:
+            module.set_weight_tensor(weight.type(dtype))
+        else:
+            module.weight.data = weight.type(dtype)
+    logger.success("Lora unfused")
     return model
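
For orientation, a minimal usage sketch of the fuse/unfuse pair defined in this file. It assumes a Flux model has already been constructed (for example via util.load_flow_model) and uses a placeholder LoRA path:

from lora_loading import apply_lora_to_model, remove_lora_from_module

# model: an existing Flux instance (e.g. from util.load_flow_model)
lora_file = "loras/my_lora.safetensors"  # placeholder path

# fuse: each targeted linear weight becomes W + lora_scale * (lora_B @ lora_A),
# with the usual alpha/rank scaling, computed in fp32 and cast back to the weight dtype
model = apply_lora_to_model(model, lora_file, lora_scale=0.8)

# ... run inference with the LoRA fused ...

# unfuse: the same scaled product is subtracted again, restoring the original weights
model = remove_lora_from_module(model, lora_file, lora_scale=0.8)

Because remove_lora_from_module re-resolves the same state dict and routes it through unfuse_lora_weight_from_module, passing the identical path and scale used when fusing recovers the pre-LoRA weights up to floating-point rounding.
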
modules/conditioner.py
CHANGED
@@ -43,6 +43,7 @@ class HFEmbedder(nn.Module):
         device: torch.device | int,
         quantization_dtype: str | None = None,
         offloading_device: torch.device | int | None = torch.device("cpu"),
+        is_clip: bool = False,
         **hf_kwargs,
     ):
         super().__init__()
@@ -54,7 +55,7 @@ class HFEmbedder(nn.Module):
         self.device = (
             device if isinstance(device, torch.device) else torch.device(device)
         )
-        self.is_clip = version.startswith("openai")
+        self.is_clip = version.startswith("openai") or is_clip
         self.max_length = max_length
         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

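
To illustrate the new flag: previously a checkpoint was treated as CLIP only when its repo id started with "openai"; is_clip lets a custom CLIP checkpoint opt in to the pooled output as well. A hedged sketch, with a hypothetical checkpoint id and keyword arguments mirroring the call site in util.py:

import torch
from modules.conditioner import HFEmbedder

clip = HFEmbedder(
    "some-org/custom-clip-vit-l-14",  # hypothetical checkpoint that does not start with "openai"
    max_length=77,
    torch_dtype=torch.float16,
    device=0,
    is_clip=True,  # forces self.is_clip, so output_key becomes "pooler_output"
)
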
modules/flux_model.py
CHANGED
@@ -1,11 +1,13 @@
-from collections import namedtuple
 import os
-from typing import TYPE_CHECKING
+from collections import namedtuple
+from typing import TYPE_CHECKING, List
+
 import torch
+from loguru import logger

 if TYPE_CHECKING:
+    from lora_loading import LoraWeights
     from util import ModelSpec
-
 DISABLE_COMPILE = os.getenv("DISABLE_COMPILE", "0") == "1"
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
@@ -14,8 +16,8 @@ torch.backends.cudnn.benchmark_limit = 20
 torch.set_float32_matmul_precision("high")
 import math

-from torch import Tensor, nn
 from pydantic import BaseModel
+from torch import Tensor, nn
 from torch.nn import functional as F


@@ -345,6 +347,7 @@ class DoubleStreamBlock(nn.Module):
         self.H = self.num_heads
         self.KH = self.K * self.H
         self.do_clamp = dtype == torch.float16
+
     def rearrange_for_norm(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
         B, L, D = x.shape
         q, k, v = x.reshape(B, L, self.K, self.H, D // self.KH).permute(2, 0, 3, 1, 4)
@@ -512,6 +515,7 @@ class Flux(nn.Module):
         self.params = config.params
         self.in_channels = config.params.in_channels
         self.out_channels = self.in_channels
+        self.loras: List[LoraWeights] = []
         prequantized_flow = config.prequantized_flow
         quantized_embedders = config.quantize_flow_embedder_layers and prequantized_flow
         quantized_modulation = config.quantize_modulation and prequantized_flow
@@ -614,6 +618,57 @@ class Flux(nn.Module):

         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

+    def get_lora(self, identifier: str):
+        for lora in self.loras:
+            if lora.path == identifier or lora.name == identifier:
+                return lora
+
+    def has_lora(self, identifier: str):
+        for lora in self.loras:
+            if lora.path == identifier or lora.name == identifier:
+                return True
+
+    def load_lora(self, path: str, scale: float, name: str = None):
+        from lora_loading import (
+            LoraWeights,
+            apply_lora_to_model,
+            remove_lora_from_module,
+        )
+
+        if self.has_lora(path):
+            lora = self.get_lora(path)
+            if lora.scale == scale:
+                logger.warning(
+                    f"Lora {lora.name} already loaded with same scale - ignoring!"
+                )
+            else:
+                remove_lora_from_module(self, lora, lora.scale)
+                apply_lora_to_model(self, lora, scale)
+                for idx, lora_ in enumerate(self.loras):
+                    if lora_.path == lora.path:
+                        self.loras[idx].scale = scale
+                        break
+        else:
+            _, lora = apply_lora_to_model(self, path, scale, return_lora_resolved=True)
+            self.loras.append(LoraWeights(lora, path, name, scale))
+
+    def unload_lora(self, path_or_identifier: str):
+        from lora_loading import remove_lora_from_module
+
+        removed = False
+        for idx, lora_ in enumerate(list(self.loras)):
+            if lora_.path == path_or_identifier or lora_.name == path_or_identifier:
+                remove_lora_from_module(self, lora_.weights, lora_.scale)
+                self.loras.pop(idx)
+                removed = True
+                break
+        if not removed:
+            logger.warning(
+                f"Couldn't remove lora {path_or_identifier} as it wasn't found fused to the model!"
+            )
+        else:
+            logger.info("Successfully removed lora from module.")
+
     def forward(
         self,
         img: Tensor,
@@ -664,9 +719,10 @@ class Flux(nn.Module):
     def from_pretrained(
         cls: "Flux", path: str, dtype: torch.dtype = torch.float16
     ) -> "Flux":
-        from util import load_config_from_path
         from safetensors.torch import load_file

+        from util import load_config_from_path
+
         config = load_config_from_path(path)
         with torch.device("meta"):
             klass = cls(config=config, dtype=dtype)
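
A short sketch of the model-level API added above, with a placeholder LoRA path. load_lora fuses the weights and records a LoraWeights entry so the same LoRA can later be unfused by path or by name:

# model is an existing Flux instance
model.load_lora("loras/my_lora.safetensors", scale=0.8, name="my_lora")

# loading the same path again with a different scale unfuses and re-fuses at the new
# scale; an identical scale is skipped with a warning
model.load_lora("loras/my_lora.safetensors", scale=0.5)

# unload by the recorded name (or by the original path)
model.unload_lora("my_lora")
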
util.py
CHANGED
@@ -34,11 +34,14 @@ class QuantizationDtype(StrEnum):
     bfloat16 = "bfloat16"
     float16 = "float16"

+
 class ModelSpec(BaseModel):
     version: ModelVersion
     params: FluxParams
     ae_params: AutoEncoderParams
     ckpt_path: str | None
+    # Add option to pass in custom clip model
+    clip_path: str | None = "openai/clip-vit-large-patch14"
     ae_path: str | None
     repo_id: str | None
     repo_flow: str | None
@@ -255,10 +258,11 @@ def load_flow_model(config: ModelSpec) -> Flux:

 def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
     clip = HFEmbedder(
-        "openai/clip-vit-large-patch14",
+        config.clip_path,
         max_length=77,
         torch_dtype=into_dtype(config.text_enc_dtype),
         device=into_device(config.text_enc_device).index or 0,
+        is_clip=True,
         quantization_dtype=config.clip_quantization_dtype,
     )
     t5 = HFEmbedder(