Ensure repo only accesses CublasLinear lazily
- float8_quantize.py +5 -1
- lora_loading.py +5 -4
float8_quantize.py

@@ -447,7 +447,11 @@ def quantize_flow_transformer_and_dispatch_float8(
         quantize_modulation=quantize_modulation,
     )
     torch.cuda.empty_cache()
-    if swap_linears_with_cublaslinear and flow_dtype == torch.float16:
+    if (
+        swap_linears_with_cublaslinear
+        and flow_dtype == torch.float16
+        and isinstance(CublasLinear, type(torch.nn.Linear))
+    ):
         swap_to_cublaslinear(flow_model)
     elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
         logger.warning("Skipping cublas linear swap because flow_dtype is not float16")
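For context on why the new guard works: `type(torch.nn.Linear)` is the metaclass of `nn.Linear` (plain `type`), so the check only passes when `CublasLinear` is bound to an actual class. The sketch below illustrates the two import outcomes; the sentinel that float8_quantize.py binds on import failure is not shown in this diff, so `None` is assumed here purely for illustration.

```python
import torch


class _FakeCublasLinear(torch.nn.Linear):
    """Stand-in for cublas_ops.CublasLinear when the import succeeds (illustrative only)."""


# Import succeeded: CublasLinear is a real class, i.e. an instance of nn.Linear's
# metaclass (`type`), so the swap branch is allowed to run.
CublasLinear = _FakeCublasLinear
print(isinstance(CublasLinear, type(torch.nn.Linear)))  # True

# Import failed: assuming a non-class sentinel such as None, the check fails and
# swap_to_cublaslinear(flow_model) is never reached.
CublasLinear = None
print(isinstance(CublasLinear, type(torch.nn.Linear)))  # False
```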
lora_loading.py

@@ -1,9 +1,12 @@
 import torch
-from cublas_ops import CublasLinear
 from loguru import logger
 from safetensors.torch import load_file
 from tqdm import tqdm
 
+try:
+    from cublas_ops import CublasLinear
+except Exception as e:
+    CublasLinear = type(None)
 from float8_quantize import F8Linear
 from modules.flux_model import Flux
 
@@ -383,7 +386,7 @@ def apply_lora_weight_to_module(
 
 
 @torch.inference_mode()
-def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0):
+def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0) -> Flux:
     has_guidance = model.params.guidance_embed
     logger.info(f"Loading LoRA weights for {lora_path}")
     lora_weights = load_file(lora_path)
@@ -408,8 +411,6 @@ def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0):
     ]
     logger.debug("Keys extracted")
     keys_without_ab = list(set(keys_without_ab))
-    if len(keys_without_ab) > 0:
-        logger.warning("Missing unconverted state dict keys!", len(keys_without_ab))
 
     for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
         module = get_module_for_key(key, model)
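A note on the lora_loading.py fallback: binding `CublasLinear` to `type(None)` keeps downstream `isinstance(module, CublasLinear)` checks valid while guaranteeing they never match a real module, so the CublasLinear-specific branches are simply skipped when cublas_ops is not installed. A minimal sketch, assuming the call sites test modules with `isinstance` (they sit outside this diff):

```python
import torch

try:
    from cublas_ops import CublasLinear  # real extension class when available
except Exception:
    # Same fallback as the diff: NoneType is a legitimate class for isinstance(),
    # but no torch module is ever an instance of it.
    CublasLinear = type(None)

module = torch.nn.Linear(4, 4)
if isinstance(module, CublasLinear):
    print("CublasLinear-specific handling")
else:
    print("plain nn.Linear path")  # always taken when cublas_ops is missing
```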