aredden committed · Commit 37bd8c1 · 1 Parent(s): 25ae92b

Dynamic swap with cublas linear / optional improved precision with vram drawback

Files changed (4)
  1. float8_quantize.py +75 -8
  2. flux_pipeline.py +7 -1
  3. modules/flux_model.py +2 -7
  4. util.py +9 -0
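
The two halves of this commit meet in quantize_flow_transformer_and_dispatch_float8 (see float8_quantize.py below): linears can be dynamically swapped to CublasLinear when running in float16, and the modulation / flow-embedder linears can optionally be left unquantized for better output quality at the cost of VRAM. A minimal usage sketch, assuming a Flux flow transformer has already been constructed; the wrapper function name is illustrative, not part of the commit:

    import torch
    from float8_quantize import quantize_flow_transformer_and_dispatch_float8

    def dispatch_for_quality(flow_model: torch.nn.Module) -> torch.nn.Module:
        # Defaults added in this commit keep everything quantized; the overrides below
        # trade VRAM for output quality.
        return quantize_flow_transformer_and_dispatch_float8(
            flow_model,
            torch.device("cuda"),
            swap_linears_with_cublaslinear=True,   # only takes effect when flow_dtype is float16
            flow_dtype=torch.float16,
            quantize_modulation=False,             # ~2GB more VRAM, larger quality gain
            quantize_flow_embedder_layers=False,   # ~512MB more VRAM, smaller quality gain
        )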
float8_quantize.py CHANGED
@@ -1,3 +1,4 @@
+from loguru import logger
 import torch
 import torch.nn as nn
 from torchao.float8.float8_utils import (
@@ -10,7 +11,8 @@ import math
 from torch.compiler import is_compiling
 from torch import __version__
 from torch.version import cuda
-from typing import TypeVar
+
+from modules.flux_model import Modulation
 
 IS_TORCH_2_4 = __version__ < (2, 4, 9)
 LT_TORCH_2_4 = __version__ < (2, 4)
@@ -42,7 +44,7 @@ class F8Linear(nn.Module):
         float8_dtype=torch.float8_e4m3fn,
         float_weight: torch.Tensor = None,
         float_bias: torch.Tensor = None,
-        num_scale_trials: int = 24,
+        num_scale_trials: int = 12,
         input_float8_dtype=torch.float8_e5m2,
     ) -> None:
         super().__init__()
@@ -183,6 +185,11 @@ class F8Linear(nn.Module):
             1, dtype=self.weight.dtype, device=self.weight.device, requires_grad=False
         )
 
+    def set_weight_tensor(self, tensor: torch.Tensor):
+        self.weight.data = tensor
+        self.weight_initialized = False
+        self.quantize_weight()
+
     def quantize_input(self, x: torch.Tensor):
         if self.input_scale_initialized:
             return to_fp8_saturated(x * self.input_scale, self.input_float8_dtype)
@@ -279,10 +286,12 @@ class F8Linear(nn.Module):
         return f8_lin
 
 
+@torch.inference_mode()
 def recursive_swap_linears(
     model: nn.Module,
     float8_dtype=torch.float8_e4m3fn,
     input_float8_dtype=torch.float8_e5m2,
+    quantize_modulation: bool = True,
 ) -> None:
     """
     Recursively swaps all nn.Linear modules in the given model with F8Linear modules.
@@ -300,6 +309,8 @@ def recursive_swap_linears(
     all linear layers in the model will be using 8-bit quantization.
     """
     for name, child in model.named_children():
+        if isinstance(child, Modulation) and not quantize_modulation:
+            continue
         if isinstance(child, nn.Linear) and not isinstance(
             child, (F8Linear, CublasLinear)
         ):
@@ -315,7 +326,35 @@ def recursive_swap_linears(
            )
            del child
        else:
-            recursive_swap_linears(child)
+            recursive_swap_linears(
+                child,
+                float8_dtype=float8_dtype,
+                input_float8_dtype=input_float8_dtype,
+                quantize_modulation=quantize_modulation,
+            )
+
+
+@torch.inference_mode()
+def swap_to_cublaslinear(model: nn.Module):
+    if not isinstance(CublasLinear, torch.nn.Module):
+        return
+    for name, child in model.named_children():
+        if isinstance(child, nn.Linear) and not isinstance(
+            child, (F8Linear, CublasLinear)
+        ):
+            cublas_lin = CublasLinear(
+                child.in_features,
+                child.out_features,
+                bias=child.bias is not None,
+                dtype=child.weight.dtype,
+                device=child.weight.device,
+            )
+            cublas_lin.weight.data = child.weight.clone().detach()
+            cublas_lin.bias.data = child.bias.clone().detach()
+            setattr(model, name, cublas_lin)
+            del child
+        else:
+            swap_to_cublaslinear(child)
 
 
 @torch.inference_mode()
@@ -325,6 +364,10 @@ def quantize_flow_transformer_and_dispatch_float8(
     float8_dtype=torch.float8_e4m3fn,
     input_float8_dtype=torch.float8_e5m2,
     offload_flow=False,
+    swap_linears_with_cublaslinear=True,
+    flow_dtype=torch.float16,
+    quantize_modulation: bool = True,
+    quantize_flow_embedder_layers: bool = True,
 ) -> nn.Module:
     """
     Quantize the flux flow transformer model (original BFL codebase version) and dispatch to the given device.
@@ -334,19 +377,36 @@ def quantize_flow_transformer_and_dispatch_float8(
     Allows for fast dispatch to gpu & quantize without causing OOM on gpus with limited memory.
 
     After dispatching, if offload_flow is True, offloads the model to cpu.
+
+    If swap_linears_with_cublaslinear is True and flow_dtype == torch.float16, swaps all linears
+    with CublasLinear for a 2x performance boost on consumer GPUs; otherwise the cublaslinear swap is skipped.
+
+    For extra precision, you can set quantize_flow_embedder_layers to False; this helps maintain
+    the output quality of the flow transformer more than fully quantizing, at the expense of
+    ~512MB more VRAM usage.
+
+    For extra precision, you can set quantize_modulation to False; this helps maintain the output
+    quality of the flow transformer more than fully quantizing, at the expense of ~2GB more VRAM
+    usage, but has a much higher impact on image quality than the embedder layers.
     """
     for module in flow_model.double_blocks:
         module.to(device)
         module.eval()
         recursive_swap_linears(
-            module, float8_dtype=float8_dtype, input_float8_dtype=input_float8_dtype
+            module,
+            float8_dtype=float8_dtype,
+            input_float8_dtype=input_float8_dtype,
+            quantize_modulation=quantize_modulation,
         )
         torch.cuda.empty_cache()
     for module in flow_model.single_blocks:
         module.to(device)
         module.eval()
         recursive_swap_linears(
-            module, float8_dtype=float8_dtype, input_float8_dtype=input_float8_dtype
+            module,
+            float8_dtype=float8_dtype,
+            input_float8_dtype=input_float8_dtype,
+            quantize_modulation=quantize_modulation,
        )
        torch.cuda.empty_cache()
    to_gpu_extras = [
@@ -364,8 +424,10 @@ def quantize_flow_transformer_and_dispatch_float8(
             continue
         m_extra.to(device)
         m_extra.eval()
-        if isinstance(m_extra, nn.Linear) and not isinstance(
-            m_extra, (F8Linear, CublasLinear)
+        if (
+            isinstance(m_extra, nn.Linear)
+            and not isinstance(m_extra, (F8Linear, CublasLinear))
+            and quantize_flow_embedder_layers
         ):
             setattr(
                 flow_model,
@@ -377,13 +439,18 @@ def quantize_flow_transformer_and_dispatch_float8(
                ),
            )
            del m_extra
-        elif module != "final_layer":
+        elif module != "final_layer" and not quantize_flow_embedder_layers:
             recursive_swap_linears(
                 m_extra,
                 float8_dtype=float8_dtype,
                 input_float8_dtype=input_float8_dtype,
+                quantize_modulation=quantize_modulation,
             )
             torch.cuda.empty_cache()
+    if swap_linears_with_cublaslinear and flow_dtype == torch.float16:
+        swap_to_cublaslinear(flow_model)
+    elif swap_linears_with_cublaslinear and flow_dtype != torch.float16:
+        logger.warning("Skipping cublas linear swap because flow_dtype is not float16")
     if offload_flow:
         flow_model.to("cpu")
         torch.cuda.empty_cache()
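
Both recursive_swap_linears and the new swap_to_cublaslinear follow the same in-place module-swap pattern: walk named_children, rebuild each eligible nn.Linear as the target type, copy its weights, and setattr the replacement back onto the parent, recursing everywhere else. A self-contained sketch of that pattern in plain PyTorch (ToySwapLinear and the toy model are illustrative stand-ins, not the repo's CublasLinear, since cublas_ops may not be installed):

    import torch
    import torch.nn as nn

    class ToySwapLinear(nn.Linear):
        """Stand-in swap target; only the swap mechanics are demonstrated."""

    def swap_linears(model: nn.Module) -> None:
        for name, child in model.named_children():
            if isinstance(child, nn.Linear) and not isinstance(child, ToySwapLinear):
                new_lin = ToySwapLinear(
                    child.in_features,
                    child.out_features,
                    bias=child.bias is not None,
                    dtype=child.weight.dtype,
                    device=child.weight.device,
                )
                new_lin.weight.data = child.weight.clone().detach()
                if child.bias is not None:
                    new_lin.bias.data = child.bias.clone().detach()
                setattr(model, name, new_lin)  # replace the child in place on its parent
            else:
                swap_linears(child)  # recurse into non-linear submodules

    toy = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 8))
    swap_linears(toy)
    print([type(m).__name__ for m in toy])  # ['ToySwapLinear', 'SiLU', 'ToySwapLinear']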
flux_pipeline.py CHANGED
@@ -645,7 +645,13 @@ class FluxPipeline:
 
         if not config.prequantized_flow:
             flow_model = quantize_flow_transformer_and_dispatch_float8(
-                flow_model, flux_device, offload_flow=config.offload_flow
+                flow_model,
+                flux_device,
+                offload_flow=config.offload_flow,
+                swap_linears_with_cublaslinear=flux_dtype == torch.float16,
+                flow_dtype=flux_dtype,
+                quantize_modulation=config.quantize_modulation,
+                quantize_flow_embedder_layers=config.quantize_flow_embedder_layers,
             )
         else:
             flow_model.eval().requires_grad_(False)
modules/flux_model.py CHANGED
@@ -14,11 +14,6 @@ from torch import Tensor, nn
 from pydantic import BaseModel
 from torch.nn import functional as F
 
-try:
-    from cublas_ops import CublasLinear
-except ImportError:
-    CublasLinear = nn.Linear
-
 
 class FluxParams(BaseModel):
     in_channels: int
@@ -350,11 +345,11 @@ class LastLayer(nn.Module):
     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
         super().__init__()
         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.linear = CublasLinear(
+        self.linear = nn.Linear(
             hidden_size, patch_size * patch_size * out_channels, bias=True
         )
         self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(), CublasLinear(hidden_size, 2 * hidden_size, bias=True)
+            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
         )
 
     def forward(self, x: Tensor, vec: Tensor) -> Tensor:
util.py CHANGED
@@ -8,12 +8,16 @@ from modules.conditioner import HFEmbedder
 from modules.flux_model import Flux, FluxParams
 from modules.flux_model_f8 import Flux as FluxF8
 from safetensors.torch import load_file as load_sft
+
 try:
     from enum import StrEnum
 except:
     from enum import Enum
+
     class StrEnum(str, Enum):
         pass
+
+
 from pydantic import BaseModel, ConfigDict
 from loguru import logger
 
@@ -61,6 +65,11 @@ class ModelSpec(BaseModel):
     offload_flow: bool = False
     prequantized_flow: bool = False
 
+    # Improved precision via not quantizing the modulation linear layers
+    quantize_modulation: bool = True
+    # Improved precision via not quantizing the flow embedder layers
+    quantize_flow_embedder_layers: bool = True
+
     model_config: ConfigDict = {
         "arbitrary_types_allowed": True,
         "use_enum_values": True,