import gc

import bitsandbytes as bnb
import torch
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
from torch import nn


def torch_gc():
    """Free as much GPU memory as possible, then run Python garbage collection."""
    if torch.cuda.is_available():
        with torch.cuda.device("cuda"):
            # Release cached blocks back to the driver and clean up stale IPC handles.
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    gc.collect()


def restart_cpu_offload(pipe, load_mode):
    """Tear down existing offload hooks on `pipe`, then re-enable model CPU offloading."""
    # `load_mode` is currently unused; it is kept so existing call sites don't break.
    optionally_disable_offloading(pipe)

    # Reclaim the memory freed by removing the hooks before re-arming offloading.
    gc.collect()
    torch.cuda.empty_cache()

    pipe.enable_model_cpu_offload()
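
# Illustrative usage (a sketch; assumes `pipe` is a diffusers pipeline that was
# offloaded earlier and has since been modified, e.g. by loading new weights):
#
#     restart_cpu_offload(pipe, load_mode=None)
#     image = pipe(prompt="a photo of an astronaut").images[0]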


def optionally_disable_offloading(_pipeline):
    """
    Optionally removes offloading in case the pipeline has already been sequentially offloaded to CPU.

    Args:
        _pipeline (`DiffusionPipeline`):
            The pipeline to disable offloading for.

    Returns:
        tuple:
            A tuple indicating whether `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
    """
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False

    if _pipeline is not None:
        # Only dereference pipeline attributes once we know the pipeline exists.
        print(f"Restarting CPU offloading for {_pipeline.unet_name}...")

        for _, component in _pipeline.components.items():
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                # Record which offloading scheme was active before removing the hooks:
                # `CpuOffload` hooks come from `enable_model_cpu_offload`, while
                # `AlignDevicesHook` comes from `enable_sequential_cpu_offload`.
                if not is_model_cpu_offload:
                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                if not is_sequential_cpu_offload:
                    is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)

                remove_hook_from_module(component, recurse=True)

    return is_model_cpu_offload, is_sequential_cpu_offload
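
# Illustrative sketch of how the returned flags can be used to restore the previous
# offloading mode afterwards (assumes `pipe` is a diffusers pipeline):
#
#     was_model_offload, was_sequential_offload = optionally_disable_offloading(pipe)
#     # ...mutate the pipeline, e.g. swap components or load weights...
#     if was_model_offload:
#         pipe.enable_model_cpu_offload()
#     elif was_sequential_offload:
#         pipe.enable_sequential_cpu_offload()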


def quantize_4bit(module):
    """Recursively replace every `torch.nn.Linear` in `module` with a 4-bit `bnb.nn.Linear4bit`."""
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            in_features = child.in_features
            out_features = child.out_features
            device = child.weight.data.device
            has_bias = child.bias is not None

            # fp16 compute keeps inference fast, and NF4 is the usual default
            # quantization type for 4-bit inference.
            bnb_4bit_compute_dtype = torch.float16
            quant_type = "nf4"

            new_layer = bnb.nn.Linear4bit(
                in_features,
                out_features,
                bias=has_bias,
                compute_dtype=bnb_4bit_compute_dtype,
                quant_type=quant_type,
            )

            # Copy the full-precision parameters; the actual 4-bit quantization
            # happens when the layer is moved to a CUDA device.
            new_layer.load_state_dict(child.state_dict())
            new_layer = new_layer.to(device)

            # Swap the quantized layer into the parent module in place.
            setattr(module, name, new_layer)
        else:
            # Recurse into non-Linear children (e.g. Sequential or attention blocks).
            quantize_4bit(child)
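
# Illustrative usage (a sketch; `model` stands in for any nn.Module, e.g. a UNet):
#
#     quantize_4bit(model)        # swap Linear layers for Linear4bit in place
#     model = model.to("cuda")    # moving to CUDA performs the actual 4-bit quantization
#     torch_gc()                  # reclaim the memory held by the fp16 weights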