aredden committed
Commit 00f5d2c · 1 Parent(s): fee1af5

Ensure repo only accesses CublasLinear lazily

Files changed (2):
  1. float8_quantize.py +5 -1
  2. lora_loading.py +5 -4
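
Before this change, lora_loading.py imported CublasLinear unconditionally at module scope, so importing the file failed outright on installs without the optional cublas_ops package. The diffs below defer that dependency: the import is wrapped in a try/except with a sentinel fallback, and the float16 swap path in float8_quantize.py additionally verifies that CublasLinear is bound to a usable class before attempting the swap.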
float8_quantize.py CHANGED
@@ -447,7 +447,11 @@ def quantize_flow_transformer_and_dispatch_float8(
         quantize_modulation=quantize_modulation,
     )
     torch.cuda.empty_cache()
-    if swap_linears_with_cublaslinear and flow_dtype == torch.float16:
+    if (
+        swap_linears_with_cublaslinear
+        and flow_dtype == torch.float16
+        and isinstance(CublasLinear, type(torch.nn.Linear))
+    ):
         swap_to_cublaslinear(flow_model)
     elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
         logger.warning("Skipping cublas linear swap because flow_dtype is not float16")
lora_loading.py CHANGED
@@ -1,9 +1,12 @@
 import torch
-from cublas_ops import CublasLinear
 from loguru import logger
 from safetensors.torch import load_file
 from tqdm import tqdm
 
+try:
+    from cublas_ops import CublasLinear
+except Exception as e:
+    CublasLinear = type(None)
 from float8_quantize import F8Linear
 from modules.flux_model import Flux
 
@@ -383,7 +386,7 @@ def apply_lora_weight_to_module(
 
 
 @torch.inference_mode()
-def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0):
+def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0) -> Flux:
     has_guidance = model.params.guidance_embed
     logger.info(f"Loading LoRA weights for {lora_path}")
     lora_weights = load_file(lora_path)
@@ -408,8 +411,6 @@ def apply_lora_to_model(model: Flux, lora_path: str, lora_scale: float = 1.0):
     ]
     logger.debug("Keys extracted")
     keys_without_ab = list(set(keys_without_ab))
-    if len(keys_without_ab) > 0:
-        logger.warning("Missing unconverted state dict keys!", len(keys_without_ab))
 
     for key in tqdm(keys_without_ab, desc="Applying LoRA", total=len(keys_without_ab)):
         module = get_module_for_key(key, model)
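
The type(None) fallback keeps the rest of lora_loading.py working unchanged: isinstance() requires a class as its second argument, and no real module is an instance of NoneType, so any isinstance check against CublasLinear quietly evaluates to False when cublas_ops is absent. A small sketch of that behavior:

import torch

CublasLinear = type(None)  # the fallback binding from the diff above

linear = torch.nn.Linear(4, 4)
print(isinstance(linear, CublasLinear))  # False: cublas-specific branch skipped
# A plain `CublasLinear = None` would instead make isinstance() raise
# TypeError, which is why the sentinel must itself be a class.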