import os
import pickle
from copy import deepcopy
from typing import Optional

import torch
from diffusers.models.activations import GEGLU, GELU

from cross_attn_hook import CrossAttentionExtractionHook
from ffn_hooker import FeedForwardHooker
from norm_attn_hook import NormHooker


# Dummy module that replaces a fully pruned block with an identity skip connection.
class SkipConnection(torch.nn.Module):
    def __init__(self):
        super(SkipConnection, self).__init__()

    def forward(self, *args, **kwargs):
        # Return the first input unchanged and ignore everything else.
        return args[0]
def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
    """Return the total lambda count, the active lambda count, and their ratio."""
    total_num_lambs = 0
    num_activate_lambs = 0
    # `binary` may be absent (e.g. for ff_hooks), in which case getattr returns None.
    binary = getattr(hooker, "binary", None)
    for lamb in hooker.lambs:
        total_num_lambs += lamb.size(0)
        if binary:
            assert threshold is None, "threshold should be None for a binary mask"
            num_activate_lambs += lamb.sum().item()
        else:
            assert (
                threshold is not None
            ), "threshold must be provided for a non-binary mask"
            num_activate_lambs += (lamb >= threshold).sum().item()
    return total_num_lambs, num_activate_lambs, num_activate_lambs / total_num_lambs
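

# Usage sketch (illustrative, not part of the original module): report the sparsity of
# each mask after `create_pipeline(..., return_hooker=True)`. The 0.5 threshold used
# for non-binary masks below is an assumed example value.
#
#   for hooker in hookers:
#       total, active, ratio = calculate_mask_sparsity(
#           hooker, threshold=None if getattr(hooker, "binary", False) else 0.5
#       )
#       print(f"{type(hooker).__name__}: {active}/{total} lambdas active ({ratio:.1%})")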
def create_pipeline(
    pipe,
    model_id,
    device,
    torch_dtype,
    save_pt=None,
    lambda_threshold: float = 1,
    binary=True,
    epsilon=0.0,
    masking="binary",
    attn_name="attn",
    return_hooker=False,
    scope=None,
    ratio=None,
):
    """
    Create the pipeline and optionally load the saved masks.
    """
    pipe.to(device)
    pipe.vae.requires_grad_(False)
    if hasattr(pipe, "unet"):
        pipe.unet.requires_grad_(False)
    else:
        pipe.transformer.requires_grad_(False)
    if save_pt:
        # TODO: merge all the hook checkpoints into one file
        if "ff.pt" in save_pt or "attn.pt" in save_pt:
            save_pts = get_save_pts(save_pt)
            cross_attn_hooker = CrossAttentionExtractionHook(
                pipe,
                model_name=model_id,
                regex=".*",
                dtype=torch_dtype,
                head_num_filter=1,
                masking=masking,  # needs to be switched to binary during inference
                dst=save_pts["attn"],
                epsilon=epsilon,
                attn_name=attn_name,
                binary=binary,
            )
            cross_attn_hooker.add_hooks(init_value=1)
            ff_hooker = FeedForwardHooker(
                pipe,
                regex=".*",
                dtype=torch_dtype,
                masking=masking,
                dst=save_pts["ff"],
                epsilon=epsilon,
                binary=binary,
            )
            ff_hooker.add_hooks(init_value=1)
            if os.path.exists(save_pts["norm"]):
                norm_hooker = NormHooker(
                    pipe,
                    regex=".*",
                    dtype=torch_dtype,
                    masking=masking,
                    dst=save_pts["norm"],
                    epsilon=epsilon,
                    binary=binary,
                )
                norm_hooker.add_hooks(init_value=1)
            else:
                norm_hooker = None
            # Run a dummy forward pass so the hooks instantiate their lambdas.
            _ = pipe("abc", num_inference_steps=1)
            cross_attn_hooker.load(device=device, threshold=lambda_threshold)
            ff_hooker.load(device=device, threshold=lambda_threshold)
            if norm_hooker:
                norm_hooker.load(device=device, threshold=lambda_threshold)
            if scope == "local" or scope == "global":
                if isinstance(ratio, float):
                    attn_hooker_ratio = ratio
                    ff_hooker_ratio = ratio
                else:
                    attn_hooker_ratio, ff_hooker_ratio = ratio[0], ratio[1]
                    if norm_hooker:
                        if len(ratio) < 3:
                            raise ValueError("Need to provide a ratio for the norm layer")
                        norm_hooker_ratio = ratio[2]
                cross_attn_hooker.binarize(scope, attn_hooker_ratio)
                ff_hooker.binarize(scope, ff_hooker_ratio)
                if norm_hooker:
                    norm_hooker.binarize(scope, norm_hooker_ratio)
            hookers = [cross_attn_hooker, ff_hooker]
            if norm_hooker:
                hookers.append(norm_hooker)
    if return_hooker:
        return pipe, hookers
    else:
        return pipe
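

# Usage sketch (the model id, checkpoint path, device and ratios below are assumptions
# for illustration only):
#
#   from diffusers import StableDiffusionPipeline
#
#   base = StableDiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
#   )
#   pipe, hookers = create_pipeline(
#       base,
#       model_id="runwayml/stable-diffusion-v1-5",
#       device="cuda",
#       torch_dtype=torch.float16,
#       save_pt="checkpoints/ff.pt",  # matching attn.pt / norm.pt paths are derived
#       return_hooker=True,
#       scope="global",
#       ratio=(0.7, 0.7),
#   )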
def linear_layer_pruning(module, lamb):
    heads_to_keep = torch.nonzero(lamb).squeeze()
    if len(heads_to_keep.shape) == 0:
        # only one head is kept (or none): keep the tensor one-dimensional
        heads_to_keep = heads_to_keep.unsqueeze(0)
    modules_to_remove = [module.to_k, module.to_q, module.to_v]
    new_heads = int(lamb.sum().item())
    if new_heads == 0:
        return SkipConnection()
    for module_to_remove in modules_to_remove:
        # per-head dimension of the projection
        inner_dim = module_to_remove.out_features // module.heads
        # placeholder for the rows to keep
        rows_to_keep = torch.zeros(
            module_to_remove.out_features,
            dtype=torch.bool,
            device=module_to_remove.weight.device,
        )
        for idx in heads_to_keep:
            rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True
        # overwrite the inner projection with the masked projection
        module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
        if module_to_remove.bias is not None:
            module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
        module_to_remove.out_features = int(rows_to_keep.sum().item())
    # Also update the output projection layer if available (for FLUXSingleAttnProcessor2_0)
    # by masking its columns (dim 1).
    if getattr(module, "to_out", None) is not None:
        module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
        module.to_out[0].in_features = int(rows_to_keep.sum().item())
    # update bookkeeping attributes on the attention module
    module.inner_dim = module.inner_dim // module.heads * new_heads
    try:
        module.query_dim = module.query_dim // module.heads * new_heads
        module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
    except AttributeError:
        pass
    module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
    module.heads = new_heads
    return module
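

# Sketch (assumptions: `attn` is a diffusers Attention module and `lamb` is a binary
# vector with one entry per attention head, e.g. taken from a loaded hooker):
#
#   lamb = torch.tensor([1, 0, 1, 1, 0, 1, 1, 1])  # keep 6 of 8 heads
#   pruned = linear_layer_pruning(attn, lamb)
#   if isinstance(pruned, SkipConnection):
#       print("all heads pruned; the block degenerates to a skip connection")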
def ffn_linear_layer_pruning(module, lamb):
    # indices of the hidden units to keep; squeeze(-1) keeps the result one-dimensional
    # even when only a single unit survives
    lambda_to_keep = torch.nonzero(lamb).squeeze(-1)
    if len(lambda_to_keep) == 0:
        return SkipConnection()
    num_lambda = len(lambda_to_keep)
    if isinstance(module.net[0], GELU):
        # prune rows of the linear projection before the activation
        module.net[0].proj.weight.data = module.net[0].proj.weight.data[
            lambda_to_keep, :
        ]
        module.net[0].proj.out_features = num_lambda
        if module.net[0].proj.bias is not None:
            module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]
        update_act = GELU(module.net[0].proj.in_features, num_lambda)
        update_act.proj = module.net[0].proj
        module.net[0] = update_act
    elif isinstance(module.net[0], GEGLU):
        # GEGLU stacks the value and gate projections, so prune both halves
        output_feature = module.net[0].proj.out_features
        module.net[0].proj.weight.data = torch.cat(
            [
                module.net[0].proj.weight.data[: output_feature // 2, :][
                    lambda_to_keep, :
                ],
                module.net[0].proj.weight.data[output_feature // 2 :][
                    lambda_to_keep, :
                ],
            ],
            dim=0,
        )
        module.net[0].proj.out_features = num_lambda * 2
        if module.net[0].proj.bias is not None:
            module.net[0].proj.bias.data = torch.cat(
                [
                    module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
                    module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
                ]
            )
        # the new activation module is only a container; its projection is replaced
        # by the pruned one below
        update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
        update_act.proj = module.net[0].proj
        module.net[0] = update_act
    # prune the columns of the projection after the activation
    module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
    module.net[2].in_features = num_lambda
    return module
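

# Sketch (assumptions: a Stable Diffusion UNet and a binary `lamb` with one entry per
# hidden unit of the feed-forward block's inner projection):
#
#   block = pipe.unet.down_blocks[0].attentions[0].transformer_blocks[0]
#   block.ff = ffn_linear_layer_pruning(block.ff, lamb)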
# Linear layer that writes its pruned outputs back into a zero-padded full-width tensor.
class SparsityLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, lambda_to_keep, num_lambda):
        super(SparsityLinear, self).__init__()
        self.linear = torch.nn.Linear(in_features, num_lambda)
        self.out_features = out_features
        self.lambda_to_keep = lambda_to_keep

    def forward(self, x):
        x = self.linear(x)
        # scatter the kept features back into a zero tensor of the original width
        output = torch.zeros(
            x.size(0), self.out_features, device=x.device, dtype=x.dtype
        )
        output[:, self.lambda_to_keep] = x
        return output
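

# Shape check (illustrative): a SparsityLinear that keeps 3 of 8 output features maps
# (batch, in_features) to (batch, 8) and leaves the pruned positions at zero.
#
#   layer = SparsityLinear(
#       in_features=16, out_features=8,
#       lambda_to_keep=torch.tensor([0, 3, 5]), num_lambda=3,
#   )
#   y = layer(torch.randn(4, 16))  # y.shape == (4, 8); columns 1, 2, 4, 6, 7 stay zero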
def norm_layer_pruning(module, lamb):
    """
    Prune the adaptive layer-normalization layer of the FLUX model.
    """
    lambda_to_keep = torch.nonzero(lamb).squeeze(-1)
    if len(lambda_to_keep) == 0:
        return SkipConnection()
    num_lambda = len(lambda_to_keep)
    # feature sizes of the modulation projection
    in_features = module.linear.in_features
    out_features = module.linear.out_features
    linear = SparsityLinear(in_features, out_features, lambda_to_keep, num_lambda)
    linear.linear.weight.data = module.linear.weight.data[lambda_to_keep]
    linear.linear.bias.data = module.linear.bias.data[lambda_to_keep]
    module.linear = linear
    return module
def get_save_pts(save_pt):
    if "ff.pt" in save_pt:
        ff_save_pt = deepcopy(save_pt)  # avoid an in-place operation
        attn_save_pt = save_pt.split(os.sep)
        attn_save_pt[-1] = attn_save_pt[-1].replace("ff", "attn")
        attn_save_pt_output = os.sep.join(attn_save_pt)
        attn_save_pt[-1] = attn_save_pt[-1].replace("attn", "norm")
        norm_save_pt = os.sep.join(attn_save_pt)
        return {
            "ff": ff_save_pt,
            "attn": attn_save_pt_output,
            "norm": norm_save_pt,
        }
    else:
        attn_save_pt = deepcopy(save_pt)
        ff_save_pt = save_pt.split(os.sep)
        ff_save_pt[-1] = ff_save_pt[-1].replace("attn", "ff")
        ff_save_pt_output = os.sep.join(ff_save_pt)
        ff_save_pt[-1] = ff_save_pt[-1].replace("ff", "norm")
        norm_save_pt = os.sep.join(ff_save_pt)
        return {
            "ff": ff_save_pt_output,
            "attn": attn_save_pt,
            "norm": norm_save_pt,
        }
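

# Example mapping (paths are illustrative): the three checkpoint names are derived from
# whichever of ff.pt / attn.pt is passed in, e.g.
#
#   get_save_pts("runs/exp1/ff.pt")
#   # -> {"ff": "runs/exp1/ff.pt", "attn": "runs/exp1/attn.pt", "norm": "runs/exp1/norm.pt"}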
def save_img(pipe, g_cpu, steps, prompt, save_path):
    image = pipe(prompt, generator=g_cpu, num_inference_steps=steps)
    image["images"][0].save(save_path)