# EcoDiff / utils.py
import os
import pickle
from copy import deepcopy
from typing import Optional
import torch
from diffusers.models.activations import GEGLU, GELU
from cross_attn_hook import CrossAttentionExtractionHook
from ffn_hooker import FeedForwardHooker
from norm_attn_hook import NormHooker
# Dummy module that replaces a fully pruned block with an identity skip connection.
class SkipConnection(torch.nn.Module):
    def __init__(self):
        super(SkipConnection, self).__init__()

    def forward(self, hidden_states, *args, **kwargs):
        # return the first input unchanged, ignoring any extra arguments
        return hidden_states
def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
    """Return (total lambdas, active lambdas, active/total ratio) for a hooker's masks."""
    total_num_lambs = 0
    num_activate_lambs = 0
    # ff_hooks have no `binary` attribute, in which case getattr returns None
    binary = getattr(hooker, "binary", None)
for lamb in hooker.lambs:
total_num_lambs += lamb.size(0)
if binary:
assert threshold is None, "threshold should be None for binary mask"
num_activate_lambs += lamb.sum().item()
else:
assert (
threshold is not None
), "threshold must be provided for non-binary mask"
num_activate_lambs += (lamb >= threshold).sum().item()
return total_num_lambs, num_activate_lambs, num_activate_lambs / total_num_lambs
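# Illustrative sketch (not part of the original module): how the sparsity report
# above can be used. The stand-in object and its values are hypothetical; it only
# mimics the `lambs`/`binary` attributes that the real hookers expose.
def _example_mask_sparsity():
    from types import SimpleNamespace

    fake_hooker = SimpleNamespace(
        binary=True, lambs=[torch.tensor([1.0, 0.0, 1.0, 1.0])]
    )
    total, active, ratio = calculate_mask_sparsity(fake_hooker)
    print(total, active, ratio)  # 4 3.0 0.75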
def create_pipeline(
pipe,
model_id,
device,
torch_dtype,
save_pt=None,
lambda_threshold: float = 1,
binary=True,
epsilon=0.0,
masking="binary",
attn_name="attn",
return_hooker=False,
scope=None,
ratio=None,
):
"""
create the pipeline and optionally load the saved mask
"""
pipe.to(device)
pipe.vae.requires_grad_(False)
if hasattr(pipe, "unet"):
pipe.unet.requires_grad_(False)
else:
pipe.transformer.requires_grad_(False)
    if save_pt:
        # TODO: merge all the hook checkpoints into one file
        if "ff.pt" in save_pt or "attn.pt" in save_pt:
            save_pts = get_save_pts(save_pt)
        else:
            raise ValueError(
                f"save_pt should point to an ff.pt or attn.pt mask checkpoint, got: {save_pt}"
            )
cross_attn_hooker = CrossAttentionExtractionHook(
pipe,
model_name=model_id,
regex=".*",
dtype=torch_dtype,
head_num_filter=1,
masking=masking, # need to change to binary during inference
dst=save_pts["attn"],
epsilon=epsilon,
attn_name=attn_name,
binary=binary,
)
cross_attn_hooker.add_hooks(init_value=1)
ff_hooker = FeedForwardHooker(
pipe,
regex=".*",
dtype=torch_dtype,
masking=masking,
dst=save_pts["ff"],
epsilon=epsilon,
binary=binary,
)
ff_hooker.add_hooks(init_value=1)
if os.path.exists(save_pts["norm"]):
norm_hooker = NormHooker(
pipe,
regex=".*",
dtype=torch_dtype,
masking=masking,
dst=save_pts["norm"],
epsilon=epsilon,
binary=binary,
)
norm_hooker.add_hooks(init_value=1)
else:
norm_hooker = None
        # run a one-step dummy generation before loading the saved masks
        _ = pipe("abc", num_inference_steps=1)
cross_attn_hooker.load(device=device, threshold=lambda_threshold)
ff_hooker.load(device=device, threshold=lambda_threshold)
if norm_hooker:
norm_hooker.load(device=device, threshold=lambda_threshold)
        if scope == "local" or scope == "global":
            if isinstance(ratio, float):
                # one ratio shared by all hook types
                attn_hooker_ratio = ff_hooker_ratio = norm_hooker_ratio = ratio
            else:
                attn_hooker_ratio, ff_hooker_ratio = ratio[0], ratio[1]
                if norm_hooker:
                    if len(ratio) < 3:
                        raise ValueError("Need to provide a ratio for the norm layer")
                    norm_hooker_ratio = ratio[2]
            cross_attn_hooker.binarize(scope, attn_hooker_ratio)
            ff_hooker.binarize(scope, ff_hooker_ratio)
            if norm_hooker:
                norm_hooker.binarize(scope, norm_hooker_ratio)
hookers = [cross_attn_hooker, ff_hooker]
if norm_hooker:
hookers.append(norm_hooker)
if return_hooker:
return pipe, hookers
else:
return pipe
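# Illustrative usage sketch (not from the original repo). The model id, dtype and
# checkpoint path below are hypothetical; any diffusers pipeline supported by the
# hook classes above should work the same way:
#
#   from diffusers import StableDiffusionXLPipeline
#   pipe = StableDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#   )
#   pipe, hookers = create_pipeline(
#       pipe,
#       model_id="stabilityai/stable-diffusion-xl-base-1.0",
#       device="cuda",
#       torch_dtype=torch.float16,
#       save_pt="ckpt/sdxl_attn.pt",  # hypothetical mask checkpoint
#       lambda_threshold=1.0,
#       return_hooker=True,
#   )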
def linear_layer_pruning(module, lamb):
    """Remove the attention heads whose lambda is zero from an attention module."""
    heads_to_keep = torch.nonzero(lamb).squeeze()
    if len(heads_to_keep.shape) == 0:
        # a 0-d result means at most one head survives; restore the head axis
        heads_to_keep = heads_to_keep.unsqueeze(0)
modules_to_remove = [module.to_k, module.to_q, module.to_v]
new_heads = int(lamb.sum().item())
if new_heads == 0:
return SkipConnection()
for module_to_remove in modules_to_remove:
# get head dimension
inner_dim = module_to_remove.out_features // module.heads
        # boolean placeholder marking the rows to keep
rows_to_keep = torch.zeros(
module_to_remove.out_features,
dtype=torch.bool,
device=module_to_remove.weight.device,
)
for idx in heads_to_keep:
rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True
# overwrite the inner projection with masked projection
module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
if module_to_remove.bias is not None:
module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
module_to_remove.out_features = int(sum(rows_to_keep).item())
    # Also update the output projection layer if available (FLUXSingleAttnProcessor2_0
    # may not have one), masking its columns along dim 1.
if getattr(module, "to_out", None) is not None:
module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
module.to_out[0].in_features = int(sum(rows_to_keep).item())
# update parameters in the attention module
module.inner_dim = module.inner_dim // module.heads * new_heads
    # some attention variants do not expose these attributes; skip them if absent
    try:
        module.query_dim = module.query_dim // module.heads * new_heads
        module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
    except AttributeError:
        pass
module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
module.heads = new_heads
return module
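# Illustrative sketch (not from the original repo): head pruning on a standalone
# diffusers Attention block. The constructor arguments are assumptions based on
# the public diffusers API, not values used by EcoDiff.
def _example_prune_attention_heads():
    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    lamb = torch.tensor([1.0, 0.0, 1.0, 0.0])  # keep heads 0 and 2
    pruned = linear_layer_pruning(attn, lamb)
    print(pruned.heads, pruned.to_q.out_features)  # 2 32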
def ffn_linear_layer_pruning(module, lamb):
    # squeeze(-1) keeps a 1-D index tensor even when only one unit survives
    lambda_to_keep = torch.nonzero(lamb).squeeze(-1)
    if len(lambda_to_keep) == 0:
        return SkipConnection()
    num_lambda = len(lambda_to_keep)
    if isinstance(module.net[0], GELU):
        # prune the output rows of the linear projection that precedes the activation
module.net[0].proj.weight.data = module.net[0].proj.weight.data[
lambda_to_keep, :
]
module.net[0].proj.out_features = num_lambda
if module.net[0].proj.bias is not None:
module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]
update_act = GELU(module.net[0].proj.in_features, num_lambda)
update_act.proj = module.net[0].proj
module.net[0] = update_act
elif isinstance(module.net[0], GEGLU):
output_feature = module.net[0].proj.out_features
module.net[0].proj.weight.data = torch.cat(
[
module.net[0].proj.weight.data[: output_feature // 2, :][
lambda_to_keep, :
],
module.net[0].proj.weight.data[output_feature // 2 :][
lambda_to_keep, :
],
],
dim=0,
)
module.net[0].proj.out_features = num_lambda * 2
if module.net[0].proj.bias is not None:
module.net[0].proj.bias.data = torch.cat(
[
module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
]
)
update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
update_act.proj = module.net[0].proj
module.net[0] = update_act
    # prune the input columns of the projection that follows the activation
module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
module.net[2].in_features = num_lambda
return module
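# Illustrative sketch (not from the original repo): hidden-unit pruning on a
# standalone diffusers FeedForward block (GEGLU activation). Sizes are hypothetical.
def _example_prune_feedforward():
    from diffusers.models.attention import FeedForward

    ff = FeedForward(dim=32, mult=2, activation_fn="geglu")  # inner dim 64
    lamb = torch.ones(64)
    lamb[::2] = 0.0  # drop every other hidden unit
    pruned = ffn_linear_layer_pruning(ff, lamb)
    out = pruned(torch.randn(1, 4, 32))  # forward still works on the pruned block
    print(pruned.net[2].in_features, out.shape)  # 32 torch.Size([1, 4, 32])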
class SparsityLinear(torch.nn.Module):
    """Linear layer that computes only the kept features and scatters them back
    into a zero tensor of the original output width."""

    def __init__(self, in_features, out_features, lambda_to_keep, num_lambda):
        super(SparsityLinear, self).__init__()
        self.linear = torch.nn.Linear(in_features, num_lambda)
        self.out_features = out_features
        self.lambda_to_keep = lambda_to_keep
def forward(self, x):
x = self.linear(x)
output = torch.zeros(
x.size(0), self.out_features, device=x.device, dtype=x.dtype
)
output[:, self.lambda_to_keep] = x
return output
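# Illustrative sketch (not part of the original module): SparsityLinear computes
# only the kept features and scatters them back to the full width, leaving the
# pruned positions at zero. All sizes below are hypothetical.
def _example_sparsity_linear():
    lambda_to_keep = torch.tensor([0, 3])  # keep output features 0 and 3 of 6
    layer = SparsityLinear(8, 6, lambda_to_keep, num_lambda=2)
    out = layer(torch.randn(4, 8))
    print(out.shape, bool((out[:, [1, 2, 4, 5]] == 0).all()))  # torch.Size([4, 6]) True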
def norm_layer_pruning(module, lamb):
    """
    Prune the linear projection inside a FLUX normalization layer.
    """
    # squeeze(-1) keeps a 1-D index tensor even when only one unit survives
    lambda_to_keep = torch.nonzero(lamb).squeeze(-1)
    if len(lambda_to_keep) == 0:
        return SkipConnection()
    num_lambda = len(lambda_to_keep)
# get num_features
in_features = module.linear.in_features
out_features = module.linear.out_features
linear = SparsityLinear(in_features, out_features, lambda_to_keep, num_lambda)
linear.linear.weight.data = module.linear.weight.data[lambda_to_keep]
linear.linear.bias.data = module.linear.bias.data[lambda_to_keep]
module.linear = linear
return module
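# Illustrative sketch (not from the original repo): norm pruning on a stand-in
# object that only exposes the `linear` sub-module the function touches. A real
# call would receive a FLUX AdaLayerNorm-style module instead.
def _example_norm_layer_pruning():
    from types import SimpleNamespace

    norm = SimpleNamespace(linear=torch.nn.Linear(16, 6))
    lamb = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0, 0.0])
    pruned = norm_layer_pruning(norm, lamb)
    print(pruned.linear(torch.randn(2, 16)).shape)  # torch.Size([2, 6]), pruned slots stay zero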
def get_save_pts(save_pt):
    """Derive the ff/attn/norm mask checkpoint paths from a single mask path."""
    if "ff.pt" in save_pt:
ff_save_pt = deepcopy(save_pt) # avoid in-place operation
attn_save_pt = save_pt.split(os.sep)
attn_save_pt[-1] = attn_save_pt[-1].replace("ff", "attn")
attn_save_pt_output = os.sep.join(attn_save_pt)
attn_save_pt[-1] = attn_save_pt[-1].replace("attn", "norm")
norm_save_pt = os.sep.join(attn_save_pt)
return {
"ff": ff_save_pt,
"attn": attn_save_pt_output,
"norm": norm_save_pt,
}
else:
attn_save_pt = deepcopy(save_pt)
ff_save_pt = save_pt.split(os.sep)
ff_save_pt[-1] = ff_save_pt[-1].replace("attn", "ff")
ff_save_pt_output = os.sep.join(ff_save_pt)
ff_save_pt[-1] = ff_save_pt[-1].replace("ff", "norm")
norm_save_pt = os.sep.join(attn_save_pt)
return {
"ff": ff_save_pt_output,
"attn": attn_save_pt,
"norm": norm_save_pt,
}
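# Illustrative sketch (hypothetical file names): the sibling checkpoint paths
# derived from a single mask path.
def _example_get_save_pts():
    paths = get_save_pts(os.path.join("ckpt", "sdxl_ff.pt"))
    # -> {"ff": ".../sdxl_ff.pt", "attn": ".../sdxl_attn.pt", "norm": ".../sdxl_norm.pt"}
    print(paths)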
def save_img(pipe, g_cpu, steps, prompt, save_path):
    """Generate a single image with the given generator and save it to ``save_path``."""
    output = pipe(prompt, generator=g_cpu, num_inference_steps=steps)
    output["images"][0].save(save_path)