zhangyang-0123 commited on
Commit
ef427a4
1 Parent(s): 2c738fc
Files changed (4) hide show
  1. cross_attn_hook.py +632 -0
  2. ffn_hooker.py +224 -0
  3. norm_attn_hook.py +242 -0
  4. utils.py +310 -0
cross_attn_hook.py ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from collections import OrderedDict
4
+ from functools import partial
5
+
6
+ import torch
7
+
8
+ import re
9
+
10
+ import math
11
+ from typing import Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from diffusers.models.attention_processor import Attention
16
+ from diffusers.utils import deprecate
17
+ from diffusers.models.embeddings import apply_rotary_emb
18
+
19
+
20
+ def scaled_dot_product_attention_atten_weight_only(
21
+ query, key, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
22
+ ) -> torch.Tensor:
23
+ L, S = query.size(-2), key.size(-2)
24
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
25
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
26
+ if is_causal:
27
+ assert attn_mask is None
28
+ temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
29
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
30
+ attn_bias.to(query.dtype)
31
+
32
+ if attn_mask is not None:
33
+ if attn_mask.dtype == torch.bool:
34
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
35
+ else:
36
+ attn_bias += attn_mask
37
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
38
+ attn_weight += attn_bias
39
+ attn_weight = torch.softmax(attn_weight, dim=-1)
40
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
41
+ return attn_weight
42
+
43
+
44
+ def apply_rope(xq, xk, freqs_cis):
45
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
46
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
47
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
48
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
49
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
50
+
51
+
52
+ def masking_fn(hidden_states, kwargs):
53
+ lamb = kwargs["lamb"].view(1, kwargs["lamb"].shape[0], 1, 1)
54
+ if kwargs.get("masking", None) == "sigmoid":
55
+ mask = torch.sigmoid(lamb)
56
+ elif kwargs.get("masking", None) == "binary":
57
+ mask = lamb
58
+ elif kwargs.get("masking", None) == "continues2binary":
59
+ # TODO: this might cause potential issue as it hard threshold at 0
60
+ mask = (lamb > 0).float()
61
+ elif kwargs.get("masking", None) == "no_masking":
62
+ mask = torch.ones_like(lamb)
63
+ else:
64
+ raise NotImplementedError
65
+ epsilon = kwargs.get("epsilon", 0.0)
66
+ hidden_states = hidden_states * mask + torch.randn_like(hidden_states) * epsilon * (
67
+ 1 - mask
68
+ )
69
+ return hidden_states
70
+
71
+
72
+ class FluxAttnProcessor2_0_Masking:
73
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
74
+
75
+ def __init__(self):
76
+ if not hasattr(F, "scaled_dot_product_attention"):
77
+ raise ImportError(
78
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
79
+ )
80
+
81
+ def __call__(
82
+ self,
83
+ attn: Attention,
84
+ hidden_states: torch.FloatTensor,
85
+ encoder_hidden_states: torch.FloatTensor = None,
86
+ attention_mask: Optional[torch.FloatTensor] = None,
87
+ image_rotary_emb: Optional[torch.Tensor] = None,
88
+ *args,
89
+ **kwargs,
90
+ ) -> torch.FloatTensor:
91
+ batch_size, _, _ = (
92
+ hidden_states.shape
93
+ if encoder_hidden_states is None
94
+ else encoder_hidden_states.shape
95
+ )
96
+
97
+ # `sample` projections.
98
+ query = attn.to_q(hidden_states)
99
+ key = attn.to_k(hidden_states)
100
+ value = attn.to_v(hidden_states)
101
+
102
+ inner_dim = key.shape[-1]
103
+ head_dim = inner_dim // attn.heads
104
+
105
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
106
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
107
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
108
+
109
+ if attn.norm_q is not None:
110
+ query = attn.norm_q(query)
111
+ if attn.norm_k is not None:
112
+ key = attn.norm_k(key)
113
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
114
+ if encoder_hidden_states is not None:
115
+ # `context` projections.
116
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
117
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
118
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
119
+
120
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
121
+ batch_size, -1, attn.heads, head_dim
122
+ ).transpose(1, 2)
123
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
124
+ batch_size, -1, attn.heads, head_dim
125
+ ).transpose(1, 2)
126
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
127
+ batch_size, -1, attn.heads, head_dim
128
+ ).transpose(1, 2)
129
+
130
+ if attn.norm_added_q is not None:
131
+ encoder_hidden_states_query_proj = attn.norm_added_q(
132
+ encoder_hidden_states_query_proj
133
+ )
134
+ if attn.norm_added_k is not None:
135
+ encoder_hidden_states_key_proj = attn.norm_added_k(
136
+ encoder_hidden_states_key_proj
137
+ )
138
+
139
+ # attention
140
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
141
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
142
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
143
+
144
+ if image_rotary_emb is not None:
145
+ query = apply_rotary_emb(query, image_rotary_emb)
146
+ key = apply_rotary_emb(key, image_rotary_emb)
147
+
148
+ hidden_states = F.scaled_dot_product_attention(
149
+ query, key, value, dropout_p=0.0, is_causal=False
150
+ )
151
+
152
+ if kwargs.get("lamb", None) is not None:
153
+ hidden_states = masking_fn(hidden_states, kwargs)
154
+
155
+ hidden_states = hidden_states.transpose(1, 2).reshape(
156
+ batch_size, -1, attn.heads * head_dim
157
+ )
158
+ hidden_states = hidden_states.to(query.dtype)
159
+
160
+ if encoder_hidden_states is not None:
161
+ encoder_hidden_states, hidden_states = (
162
+ hidden_states[:, : encoder_hidden_states.shape[1]],
163
+ hidden_states[:, encoder_hidden_states.shape[1] :],
164
+ )
165
+
166
+ # linear proj
167
+ hidden_states = attn.to_out[0](hidden_states)
168
+ # dropout
169
+ hidden_states = attn.to_out[1](hidden_states)
170
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
171
+
172
+ return hidden_states, encoder_hidden_states
173
+ else:
174
+ return hidden_states
175
+
176
+
177
+ class AttnProcessor2_0_Masking:
178
+ r"""
179
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
180
+ """
181
+
182
+ def __init__(self):
183
+ if not hasattr(F, "scaled_dot_product_attention"):
184
+ raise ImportError(
185
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
186
+ )
187
+
188
+ def __call__(
189
+ self,
190
+ attn: Attention,
191
+ hidden_states: torch.Tensor,
192
+ encoder_hidden_states: Optional[torch.Tensor] = None,
193
+ attention_mask: Optional[torch.Tensor] = None,
194
+ temb: Optional[torch.Tensor] = None,
195
+ *args,
196
+ **kwargs,
197
+ ):
198
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
199
+ deprecation_message = (
200
+ "The `scale` argument is deprecated and will be ignored. "
201
+ "Please remove it, as passing it will raise an error "
202
+ "in the future. `scale` should directly be passed while "
203
+ "calling the underlying pipeline component i.e., via "
204
+ "`cross_attention_kwargs`."
205
+ )
206
+ deprecate("scale", "1.0.0", deprecation_message)
207
+
208
+ residual = hidden_states
209
+ if attn.spatial_norm is not None:
210
+ hidden_states = attn.spatial_norm(hidden_states, temb)
211
+
212
+ input_ndim = hidden_states.ndim
213
+
214
+ if input_ndim == 4:
215
+ batch_size, channel, height, width = hidden_states.shape
216
+ hidden_states = hidden_states.view(
217
+ batch_size, channel, height * width
218
+ ).transpose(1, 2)
219
+
220
+ batch_size, sequence_length, _ = (
221
+ hidden_states.shape
222
+ if encoder_hidden_states is None
223
+ else encoder_hidden_states.shape
224
+ )
225
+
226
+ if attention_mask is not None:
227
+ attention_mask = attn.prepare_attention_mask(
228
+ attention_mask, sequence_length, batch_size
229
+ )
230
+ # scaled_dot_product_attention expects attention_mask shape to be
231
+ # (batch, heads, source_length, target_length)
232
+ attention_mask = attention_mask.view(
233
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
234
+ )
235
+
236
+ if attn.group_norm is not None:
237
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
238
+ 1, 2
239
+ )
240
+
241
+ query = attn.to_q(hidden_states)
242
+
243
+ if encoder_hidden_states is None:
244
+ encoder_hidden_states = hidden_states
245
+ elif attn.norm_cross:
246
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
247
+ encoder_hidden_states
248
+ )
249
+
250
+ key = attn.to_k(encoder_hidden_states)
251
+ value = attn.to_v(encoder_hidden_states)
252
+
253
+ inner_dim = key.shape[-1]
254
+ head_dim = inner_dim // attn.heads
255
+
256
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
257
+
258
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
259
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
260
+
261
+ if getattr(attn, "norm_q", None) is not None:
262
+ query = attn.norm_q(query)
263
+
264
+ if getattr(attn, "norm_k", None) is not None:
265
+ key = attn.norm_k(key)
266
+
267
+ hidden_states = F.scaled_dot_product_attention(
268
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
269
+ )
270
+
271
+ if kwargs.get("return_attention", True):
272
+ # add the attention output from F.scaled_dot_product_attention
273
+ attn_weight = scaled_dot_product_attention_atten_weight_only(
274
+ query, key, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
275
+ )
276
+ hidden_states_aft_attention_ops = hidden_states.clone()
277
+ attn_weight_old = attn_weight.to(hidden_states.device).clone()
278
+ else:
279
+ hidden_states_aft_attention_ops = None
280
+ attn_weight_old = None
281
+
282
+ # masking for the hidden_states after the attention ops
283
+ if kwargs.get("lamb", None) is not None:
284
+ hidden_states = masking_fn(hidden_states, kwargs)
285
+
286
+ hidden_states = hidden_states.transpose(1, 2).reshape(
287
+ batch_size, -1, attn.heads * head_dim
288
+ )
289
+ hidden_states = hidden_states.to(query.dtype)
290
+
291
+ # linear proj
292
+ hidden_states = attn.to_out[0](hidden_states)
293
+ # dropout
294
+ hidden_states = attn.to_out[1](hidden_states)
295
+
296
+ if input_ndim == 4:
297
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
298
+ batch_size, channel, height, width
299
+ )
300
+
301
+ if attn.residual_connection:
302
+ hidden_states = hidden_states + residual
303
+
304
+ hidden_states = hidden_states / attn.rescale_output_factor
305
+
306
+ return hidden_states, hidden_states_aft_attention_ops, attn_weight_old
307
+
308
+
309
+ class BaseCrossAttentionHooker:
310
+ def __init__(
311
+ self,
312
+ pipeline,
313
+ regex,
314
+ dtype,
315
+ head_num_filter,
316
+ masking,
317
+ model_name,
318
+ attn_name,
319
+ use_log,
320
+ eps,
321
+ ):
322
+ self.pipeline = pipeline
323
+ # unet for SD2 SDXL, transformer for SD3, FLUX DIT
324
+ self.net = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
325
+ self.model_name = model_name
326
+ self.module_heads = OrderedDict()
327
+ self.masking = masking
328
+ self.hook_dict = {}
329
+ self.regex = regex
330
+ self.dtype = dtype
331
+ self.head_num_filter = head_num_filter
332
+ self.attn_name = attn_name
333
+ self.logger = logging.getLogger(__name__)
334
+ self.use_log = use_log # use log parameter to control hard_discrete
335
+ self.eps = eps
336
+
337
+ def add_hooks_to_cross_attention(self, hook_fn: callable):
338
+ """
339
+ Add forward hooks to every cross attention
340
+ :param hook_fn: a callable to be added to torch nn module as a hook
341
+ :return:
342
+ """
343
+ total_hooks = 0
344
+ for name, module in self.net.named_modules():
345
+ name_last_word = name.split(".")[-1]
346
+ if self.attn_name in name_last_word:
347
+ if re.match(self.regex, name):
348
+ hook_fn = partial(hook_fn, name=name)
349
+ hook = module.register_forward_hook(hook_fn, with_kwargs=True)
350
+ self.hook_dict[name] = hook
351
+ self.module_heads[name] = module.heads
352
+ self.logger.info(
353
+ f"Adding hook to {name}, module.heads: {module.heads}"
354
+ )
355
+ total_hooks += 1
356
+ self.logger.info(f"Total hooks added: {total_hooks}")
357
+
358
+ def clear_hooks(self):
359
+ """clear all hooks"""
360
+ for hook in self.hook_dict.values():
361
+ hook.remove()
362
+ self.hook_dict.clear()
363
+
364
+
365
+ class CrossAttentionExtractionHook(BaseCrossAttentionHooker):
366
+ def __init__(
367
+ self,
368
+ pipeline,
369
+ dtype,
370
+ head_num_filter,
371
+ masking,
372
+ dst,
373
+ regex=None,
374
+ epsilon=0.0,
375
+ binary=False,
376
+ return_attention=False,
377
+ model_name="sdxl",
378
+ attn_name="attn",
379
+ use_log=False,
380
+ eps=1e-6,
381
+ ):
382
+ super().__init__(
383
+ pipeline,
384
+ regex,
385
+ dtype,
386
+ head_num_filter,
387
+ masking=masking,
388
+ model_name=model_name,
389
+ attn_name=attn_name,
390
+ use_log=use_log,
391
+ eps=eps,
392
+ )
393
+ if model_name == "sdxl":
394
+ self.attention_processor = AttnProcessor2_0_Masking()
395
+ elif model_name == "flux":
396
+ self.attention_processor = FluxAttnProcessor2_0_Masking()
397
+ self.lambs = []
398
+ self.lambs_module_names = []
399
+ self.cross_attn = []
400
+ self.hook_counter = 0
401
+ self.device = (
402
+ self.pipeline.unet.device
403
+ if hasattr(self.pipeline, "unet")
404
+ else self.pipeline.transformer.device
405
+ )
406
+ self.dst = dst
407
+ self.epsilon = epsilon
408
+ self.binary = binary
409
+ self.return_attention = return_attention
410
+ self.model_name = model_name
411
+
412
+ def clean_cross_attn(self):
413
+ self.cross_attn = []
414
+
415
+ def validate_dst(self):
416
+ if os.path.exists(self.dst):
417
+ raise ValueError(f"Destination {self.dst} already exists")
418
+
419
+ def save(self, name: str = None):
420
+ if name is not None:
421
+ dst = os.path.join(os.path.dirname(self.dst), name)
422
+ else:
423
+ dst = self.dst
424
+ dst_dir = os.path.dirname(dst)
425
+ if not os.path.exists(dst_dir):
426
+ self.logger.info(f"Creating directory {dst_dir}")
427
+ os.makedirs(dst_dir)
428
+ torch.save(self.lambs, dst)
429
+
430
+ @property
431
+ def get_lambda_block_names(self):
432
+ return self.lambs_module_names
433
+
434
+ def load(self, device, threshold=2.5):
435
+ if os.path.exists(self.dst):
436
+ self.logger.info(f"loading lambda from {self.dst}")
437
+ self.lambs = torch.load(self.dst, weights_only=True, map_location=device)
438
+ if self.binary:
439
+ # set binary masking for each lambda by using clamp
440
+ self.lambs = [
441
+ (torch.relu(lamb - threshold) > 0).float() for lamb in self.lambs
442
+ ]
443
+ else:
444
+ self.logger.info("skipping loading, training from scratch")
445
+
446
+ def binarize(self, scope: str, ratio: float):
447
+ assert scope in ["local", "global"], "scope must be either local or global"
448
+ assert (
449
+ not self.binary
450
+ ), "binarization is not supported when using binary mask already"
451
+ if scope == "local":
452
+ # Local binarization
453
+ for i, lamb in enumerate(self.lambs):
454
+ num_heads = lamb.size(0)
455
+ num_activate_heads = int(num_heads * ratio)
456
+ # Sort the lambda values with stable sorting to maintain order for equal values
457
+ sorted_lamb, sorted_indices = torch.sort(
458
+ lamb, descending=True, stable=True
459
+ )
460
+ # Find the threshold value
461
+ threshold = sorted_lamb[num_activate_heads - 1]
462
+ # Create a mask based on the sorted indices
463
+ mask = torch.zeros_like(lamb)
464
+ mask[sorted_indices[:num_activate_heads]] = 1.0
465
+ # Binarize the lambda based on the threshold and the mask
466
+ self.lambs[i] = torch.where(
467
+ lamb > threshold, torch.ones_like(lamb), mask
468
+ )
469
+ else:
470
+ # Global binarization
471
+ all_lambs = torch.cat([lamb.flatten() for lamb in self.lambs])
472
+ num_total = all_lambs.numel()
473
+ num_activate = int(num_total * ratio)
474
+ # Sort all lambda values globally with stable sorting
475
+ sorted_lambs, sorted_indices = torch.sort(
476
+ all_lambs, descending=True, stable=True
477
+ )
478
+ # Find the global threshold value
479
+ threshold = sorted_lambs[num_activate - 1]
480
+ # Create a global mask based on the sorted indices
481
+ global_mask = torch.zeros_like(all_lambs)
482
+ global_mask[sorted_indices[:num_activate]] = 1.0
483
+ # Binarize all lambdas based on the global threshold and mask
484
+ start_idx = 0
485
+ for i in range(len(self.lambs)):
486
+ end_idx = start_idx + self.lambs[i].numel()
487
+ lamb_mask = global_mask[start_idx:end_idx].reshape(self.lambs[i].shape)
488
+ self.lambs[i] = torch.where(
489
+ self.lambs[i] > threshold, torch.ones_like(self.lambs[i]), lamb_mask
490
+ )
491
+ start_idx = end_idx
492
+ self.binary = True
493
+
494
+ def bizarize_threshold(self, threshold: float):
495
+ """
496
+ Binarize lambda values based on a predefined threshold.
497
+ :param threshold: The threshold value for binarization
498
+ """
499
+ assert (
500
+ not self.binary
501
+ ), "Binarization is not supported when using binary mask already"
502
+
503
+ for i in range(len(self.lambs)):
504
+ self.lambs[i] = (self.lambs[i] >= threshold).float()
505
+
506
+ self.binary = True
507
+
508
+ def get_cross_attn_extraction_hook(self, init_value=1.0):
509
+ """get a hook function to extract cross attention"""
510
+
511
+ # the reason to use a function inside a function is to save the extracted cross attention
512
+ def hook_fn(module, args, kwargs, output, name):
513
+ # initialize lambda with acual head dim in the first run
514
+ if self.lambs[self.hook_counter] is None:
515
+ self.lambs[self.hook_counter] = (
516
+ torch.ones(
517
+ module.heads, device=self.pipeline.device, dtype=self.dtype
518
+ )
519
+ * init_value
520
+ )
521
+ # Only set requires_grad to True when the head number is larger than the filter
522
+ if self.head_num_filter <= module.heads:
523
+ self.lambs[self.hook_counter].requires_grad = True
524
+
525
+ # load attn lambda module name for logging
526
+ self.lambs_module_names[self.hook_counter] = name
527
+
528
+ if self.model_name == "sdxl":
529
+ hidden_states, _, attention_output = self.attention_processor(
530
+ module,
531
+ args[0],
532
+ encoder_hidden_states=kwargs["encoder_hidden_states"],
533
+ attention_mask=kwargs["attention_mask"],
534
+ lamb=self.lambs[self.hook_counter],
535
+ masking=self.masking,
536
+ epsilon=self.epsilon,
537
+ return_attention=self.return_attention,
538
+ use_log=self.use_log,
539
+ eps=self.eps,
540
+ )
541
+ if attention_output is not None:
542
+ self.cross_attn.append(attention_output)
543
+ self.hook_counter += 1
544
+ self.hook_counter %= len(self.lambs)
545
+ return hidden_states
546
+ elif self.model_name == "flux":
547
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
548
+ # flux has two different attention processors, FluxSingleAttnProcessor and FluxAttnProcessor
549
+ if "single" in name:
550
+ hidden_states = self.attention_processor(
551
+ module,
552
+ hidden_states=kwargs.get("hidden_states", None),
553
+ encoder_hidden_states=encoder_hidden_states,
554
+ attention_mask=kwargs.get("attention_mask", None),
555
+ image_rotary_emb=kwargs.get("image_rotary_emb", None),
556
+ lamb=self.lambs[self.hook_counter],
557
+ masking=self.masking,
558
+ epsilon=self.epsilon,
559
+ use_log=self.use_log,
560
+ eps=self.eps,
561
+ )
562
+ self.hook_counter += 1
563
+ self.hook_counter %= len(self.lambs)
564
+ return hidden_states
565
+ else:
566
+ hidden_states, encoder_hidden_states = self.attention_processor(
567
+ module,
568
+ hidden_states=kwargs.get("hidden_states", None),
569
+ encoder_hidden_states=encoder_hidden_states,
570
+ attention_mask=kwargs.get("attention_mask", None),
571
+ image_rotary_emb=kwargs.get("image_rotary_emb", None),
572
+ lamb=self.lambs[self.hook_counter],
573
+ masking=self.masking,
574
+ epsilon=self.epsilon,
575
+ use_log=self.use_log,
576
+ eps=self.eps,
577
+ )
578
+ self.hook_counter += 1
579
+ self.hook_counter %= len(self.lambs)
580
+ return hidden_states, encoder_hidden_states
581
+
582
+ return hook_fn
583
+
584
+ def add_hooks(self, init_value=1.0):
585
+ hook_fn = self.get_cross_attn_extraction_hook(init_value)
586
+ self.add_hooks_to_cross_attention(hook_fn)
587
+ # initialize the lambda
588
+ self.lambs = [None] * len(self.module_heads)
589
+ # initialize the lambda module names
590
+ self.lambs_module_names = [None] * len(self.module_heads)
591
+
592
+ def get_process_cross_attn_result(self, text_seq_length, timestep: int = -1):
593
+ if isinstance(timestep, str):
594
+ timestep = int(timestep)
595
+ # num_lambda_block contains lambda (head masking)
596
+ num_lambda_block = len(self.lambs)
597
+
598
+ # get the start and end position of the timestep
599
+ start_pos = timestep * num_lambda_block
600
+ end_pos = (timestep + 1) * num_lambda_block
601
+ if end_pos > len(self.cross_attn):
602
+ raise ValueError(f"timestep {timestep} is out of range")
603
+
604
+ # list[cross_attn_map] num_layer x [batch, num_heads, seq_vis_tokens, seq_text_tokens]
605
+ attn_maps = self.cross_attn[start_pos:end_pos]
606
+
607
+ def heatmap(attn_list, attn_idx, head_idx, text_idx):
608
+ # only select second element in the tuple (with text guided attention)
609
+ # layer_idx, 1, head_idx, seq_vis_tokens, seq_text_tokens
610
+ map = attn_list[attn_idx][1][head_idx][:][:, text_idx]
611
+ # get the size of the heatmap
612
+ size = int(map.shape[0] ** 0.5)
613
+ map = map.view(size, size, 1)
614
+ data = map.cpu().float().numpy()
615
+ return data
616
+
617
+ output_dict = {}
618
+ for lambda_block_idx, lambda_block_name in zip(
619
+ range(num_lambda_block), self.lambs_module_names
620
+ ):
621
+ data_list = []
622
+ for head_idx in range(len(self.lambs[lambda_block_idx])):
623
+ for token_idx in range(text_seq_length):
624
+ # number of heatmap is equal to the number of tokens in the text sequence X number of heads
625
+ data_list.append(
626
+ heatmap(attn_maps, lambda_block_idx, head_idx, token_idx)
627
+ )
628
+ output_dict[lambda_block_name] = {
629
+ "attn_map": data_list,
630
+ "lambda": self.lambs[lambda_block_idx],
631
+ }
632
+ return output_dict
ffn_hooker.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from collections import OrderedDict
4
+ from functools import partial
5
+
6
+ import diffusers
7
+ import torch
8
+ from torch import nn
9
+
10
+ import re
11
+
12
+
13
+ class FeedForwardHooker:
14
+ def __init__(
15
+ self,
16
+ pipeline: nn.Module,
17
+ regex: str,
18
+ dtype: torch.dtype,
19
+ masking: str,
20
+ dst: str,
21
+ epsilon: float = 0.0,
22
+ eps: float = 1e-6,
23
+ use_log: bool = False,
24
+ binary: bool = False,
25
+ ):
26
+ self.pipeline = pipeline
27
+ self.net = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
28
+ self.logger = logging.getLogger(__name__)
29
+ self.dtype = dtype
30
+ self.regex = regex
31
+ self.hook_dict = {}
32
+ self.masking = masking
33
+ self.dst = dst
34
+ self.epsilon = epsilon
35
+ self.eps = eps
36
+ self.use_log = use_log
37
+ self.lambs = []
38
+ self.lambs_module_names = [] # store the module names for each lambda block
39
+ self.hook_counter = 0
40
+ self.module_neurons = OrderedDict()
41
+ self.binary = (
42
+ binary # default, need to discuss if we need to keep this attribute or not
43
+ )
44
+
45
+ def add_hooks_to_ff(self, hook_fn: callable):
46
+ total_hooks = 0
47
+ for name, module in self.net.named_modules():
48
+ name_last_word = name.split(".")[-1]
49
+ if "ff" in name_last_word:
50
+ if re.match(self.regex, name):
51
+ hook_fn_with_name = partial(hook_fn, name=name)
52
+ actual_module = module.net[0]
53
+ hook = actual_module.register_forward_hook(
54
+ hook_fn_with_name, with_kwargs=True
55
+ )
56
+ self.hook_dict[name] = hook
57
+
58
+ if isinstance(
59
+ actual_module, diffusers.models.activations.GEGLU
60
+ ): # geglu
61
+ # due to the GEGLU chunking, we need to divide by 2
62
+ self.module_neurons[name] = actual_module.proj.out_features // 2
63
+ elif isinstance(
64
+ actual_module, diffusers.models.activations.GELU
65
+ ): # gelu
66
+ self.module_neurons[name] = actual_module.proj.out_features
67
+ else:
68
+ raise NotImplementedError(
69
+ f"Module {name} is not implemented, please check"
70
+ )
71
+ self.logger.info(
72
+ f"Adding hook to {name}, neurons: {self.module_neurons[name]}"
73
+ )
74
+ total_hooks += 1
75
+ self.logger.info(f"Total hooks added: {total_hooks}")
76
+ return self.hook_dict
77
+
78
+ def add_hooks(self, init_value=1.0):
79
+ hook_fn = self.get_ff_masking_hook(init_value)
80
+ self.add_hooks_to_ff(hook_fn)
81
+ # initialize the lambda
82
+ self.lambs = [None] * len(self.hook_dict)
83
+ # initialize the lambda module names
84
+ self.lambs_module_names = [None] * len(self.hook_dict)
85
+
86
+ def clear_hooks(self):
87
+ """clear all hooks"""
88
+ for hook in self.hook_dict.values():
89
+ hook.remove()
90
+ self.hook_dict.clear()
91
+
92
+ def save(self, name: str = None):
93
+ if name is not None:
94
+ dst = os.path.join(os.path.dirname(self.dst), name)
95
+ else:
96
+ dst = self.dst
97
+ dst_dir = os.path.dirname(dst)
98
+ if not os.path.exists(dst_dir):
99
+ self.logger.info(f"Creating directory {dst_dir}")
100
+ os.makedirs(dst_dir)
101
+ torch.save(self.lambs, dst)
102
+
103
+ @property
104
+ def get_lambda_block_names(self):
105
+ return self.lambs_module_names
106
+
107
+ def load(self, device, threshold=2.5):
108
+ if os.path.exists(self.dst):
109
+ self.logger.info(f"loading lambda from {self.dst}")
110
+ self.lambs = torch.load(self.dst, weights_only=True, map_location=device)
111
+ if self.binary:
112
+ # set binary masking for each lambda by using clamp
113
+ self.lambs = [
114
+ (torch.relu(lamb - threshold) > 0).float() for lamb in self.lambs
115
+ ]
116
+ else:
117
+ self.lambs = [torch.clamp(lamb, min=0.0) for lamb in self.lambs]
118
+ # self.lambs_module_names = [None for _ in self.lambs]
119
+ else:
120
+ self.logger.info("skipping loading, training from scratch")
121
+
122
+ def binarize(self, scope: str, ratio: float):
123
+ assert scope in ["local", "global"], "scope must be either local or global"
124
+ assert (
125
+ not self.binary
126
+ ), "binarization is not supported when using binary mask already"
127
+ if scope == "local":
128
+ # Local binarization
129
+ for i, lamb in enumerate(self.lambs):
130
+ num_heads = lamb.size(0)
131
+ num_activate_heads = int(num_heads * ratio)
132
+ # Sort the lambda values with stable sorting to maintain order for equal values
133
+ sorted_lamb, sorted_indices = torch.sort(
134
+ lamb, descending=True, stable=True
135
+ )
136
+ # Find the threshold value
137
+ threshold = sorted_lamb[num_activate_heads - 1]
138
+ # Create a mask based on the sorted indices
139
+ mask = torch.zeros_like(lamb)
140
+ mask[sorted_indices[:num_activate_heads]] = 1.0
141
+ # Binarize the lambda based on the threshold and the mask
142
+ self.lambs[i] = torch.where(
143
+ lamb > threshold, torch.ones_like(lamb), mask
144
+ )
145
+ else:
146
+ # Global binarization
147
+ all_lambs = torch.cat([lamb.flatten() for lamb in self.lambs])
148
+ num_total = all_lambs.numel()
149
+ num_activate = int(num_total * ratio)
150
+ # Sort all lambda values globally with stable sorting
151
+ sorted_lambs, sorted_indices = torch.sort(
152
+ all_lambs, descending=True, stable=True
153
+ )
154
+ # Find the global threshold value
155
+ threshold = sorted_lambs[num_activate - 1]
156
+ # Create a global mask based on the sorted indices
157
+ global_mask = torch.zeros_like(all_lambs)
158
+ global_mask[sorted_indices[:num_activate]] = 1.0
159
+ # Binarize all lambdas based on the global threshold and mask
160
+ start_idx = 0
161
+ for i in range(len(self.lambs)):
162
+ end_idx = start_idx + self.lambs[i].numel()
163
+ lamb_mask = global_mask[start_idx:end_idx].reshape(self.lambs[i].shape)
164
+ self.lambs[i] = torch.where(
165
+ self.lambs[i] > threshold, torch.ones_like(self.lambs[i]), lamb_mask
166
+ )
167
+ start_idx = end_idx
168
+ self.binary = True
169
+
170
+ @staticmethod
171
+ def masking_fn(hidden_states, **kwargs):
172
+ hidden_states_dtype = hidden_states.dtype
173
+ lamb = kwargs["lamb"].view(1, 1, kwargs["lamb"].shape[0])
174
+ if kwargs.get("masking", None) == "sigmoid":
175
+ mask = torch.sigmoid(lamb)
176
+ elif kwargs.get("masking", None) == "binary":
177
+ mask = lamb
178
+ elif kwargs.get("masking", None) == "continues2binary":
179
+ # TODO: this might cause potential issue as it hard threshold at 0
180
+ mask = (lamb > 0).float()
181
+ elif kwargs.get("masking", None) == "no_masking":
182
+ mask = torch.ones_like(lamb)
183
+ else:
184
+ raise NotImplementedError
185
+ epsilon = kwargs.get("epsilon", 0.0)
186
+ hidden_states = hidden_states * mask + torch.randn_like(
187
+ hidden_states
188
+ ) * epsilon * (1 - mask)
189
+ return hidden_states.to(hidden_states_dtype)
190
+
191
+ def get_ff_masking_hook(self, init_value=1.0):
192
+ """
193
+ Get a hook function to mask feed forward layer
194
+ """
195
+
196
+ def hook_fn(module, args, kwargs, output, name):
197
+ # initialize lambda with acual head dim in the first run
198
+ if self.lambs[self.hook_counter] is None:
199
+ self.lambs[self.hook_counter] = (
200
+ torch.ones(
201
+ self.module_neurons[name],
202
+ device=self.pipeline.device,
203
+ dtype=self.dtype,
204
+ )
205
+ * init_value
206
+ )
207
+ self.lambs[self.hook_counter].requires_grad = True
208
+ # load ff lambda module name for logging
209
+ self.lambs_module_names[self.hook_counter] = name
210
+
211
+ # perform masking
212
+ output = self.masking_fn(
213
+ output,
214
+ masking=self.masking,
215
+ lamb=self.lambs[self.hook_counter],
216
+ epsilon=self.epsilon,
217
+ eps=self.eps,
218
+ use_log=self.use_log,
219
+ )
220
+ self.hook_counter += 1
221
+ self.hook_counter %= len(self.lambs)
222
+ return output
223
+
224
+ return hook_fn
norm_attn_hook.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TODO should be a parent class for all the hooks !! for the offical repo
2
+ # 1: FLUX Norm
3
+
4
+ import logging
5
+ import os
6
+ from collections import OrderedDict
7
+ from functools import partial
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ import re
13
+
14
+
15
+ class NormHooker:
16
+ def __init__(
17
+ self,
18
+ pipeline: nn.Module,
19
+ regex: str,
20
+ dtype: torch.dtype,
21
+ masking: str,
22
+ dst: str,
23
+ epsilon: float = 0.0,
24
+ eps: float = 1e-6,
25
+ use_log: bool = False,
26
+ binary: bool = False,
27
+ ):
28
+ self.pipeline = pipeline
29
+ self.net = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
30
+ self.logger = logging.getLogger(__name__)
31
+ self.dtype = dtype
32
+ self.regex = regex
33
+ self.hook_dict = {}
34
+ self.masking = masking
35
+ self.dst = dst
36
+ self.epsilon = epsilon
37
+ self.eps = eps
38
+ self.use_log = use_log
39
+ self.lambs = []
40
+ self.lambs_module_names = [] # store the module names for each lambda block
41
+ self.hook_counter = 0
42
+ self.module_neurons = OrderedDict()
43
+ self.binary = (
44
+ binary # default, need to discuss if we need to keep this attribute or not
45
+ )
46
+
47
+ def add_hooks_to_norm(self, hook_fn: callable):
48
+ """
49
+ Add forward hooks to every feed forward layer matching the regex
50
+ :param hook_fn: a callable to be added to torch nn module as a hook
51
+ :return: dictionary of added hooks
52
+ """
53
+ total_hooks = 0
54
+ for name, module in self.net.named_modules():
55
+ name_last_word = name.split(".")[-1]
56
+ if "norm1_context" in name_last_word:
57
+ if re.match(self.regex, name):
58
+ hook_fn_with_name = partial(hook_fn, name=name)
59
+
60
+ if hasattr(module, "linear"):
61
+ actual_module = module.linear
62
+ else:
63
+ if isinstance(module, nn.Linear):
64
+ actual_module = module
65
+ else:
66
+ continue
67
+
68
+ hook = actual_module.register_forward_hook(
69
+ hook_fn_with_name, with_kwargs=True
70
+ )
71
+ self.hook_dict[name] = hook
72
+
73
+ # AdaLayerNormZero
74
+ if isinstance(actual_module, torch.nn.Linear):
75
+ self.module_neurons[name] = actual_module.out_features
76
+ else:
77
+ raise NotImplementedError(
78
+ f"Module {name} is not implemented, please check"
79
+ )
80
+ self.logger.info(
81
+ f"Adding hook to {name}, neurons: {self.module_neurons[name]}"
82
+ )
83
+ total_hooks += 1
84
+ self.logger.info(f"Total hooks added: {total_hooks}")
85
+ return self.hook_dict
86
+
87
+ def add_hooks(self, init_value=1.0):
88
+ hook_fn = self.get_norm_masking_hook(init_value)
89
+ self.add_hooks_to_norm(hook_fn)
90
+ # initialize the lambda
91
+ self.lambs = [None] * len(self.hook_dict)
92
+ # initialize the lambda module names
93
+ self.lambs_module_names = [None] * len(self.hook_dict)
94
+
95
+ def clear_hooks(self):
96
+ """clear all hooks"""
97
+ for hook in self.hook_dict.values():
98
+ hook.remove()
99
+ self.hook_dict.clear()
100
+
101
+ def save(self, name: str = None):
102
+ if name is not None:
103
+ dst = os.path.join(os.path.dirname(self.dst), name)
104
+ else:
105
+ dst = self.dst
106
+ dst_dir = os.path.dirname(dst)
107
+ if not os.path.exists(dst_dir):
108
+ self.logger.info(f"Creating directory {dst_dir}")
109
+ os.makedirs(dst_dir)
110
+ torch.save(self.lambs, dst)
111
+
112
+ @property
113
+ def get_lambda_block_names(self):
114
+ return self.lambs_module_names
115
+
116
+ def load(self, device, threshold):
117
+ if os.path.exists(self.dst):
118
+ self.logger.info(f"loading lambda from {self.dst}")
119
+ self.lambs = torch.load(self.dst, weights_only=True, map_location=device)
120
+ if self.binary:
121
+ # set binary masking for each lambda by using clamp
122
+ self.lambs = [
123
+ (torch.relu(lamb - threshold) > 0).float() for lamb in self.lambs
124
+ ]
125
+ else:
126
+ self.lambs = [torch.clamp(lamb, min=0.0) for lamb in self.lambs]
127
+ # self.lambs_module_names = [None for _ in self.lambs]
128
+ else:
129
+ self.logger.info("skipping loading, training from scratch")
130
+
131
+ def binarize(self, scope: str, ratio: float):
132
+ """
133
+ binarize lambda to be 0 or 1
134
+ :param scope: either locally (sparsity within layer) or globally (sparsity within model)
135
+ :param ratio: the ratio of the number of 1s to the total number of elements
136
+ """
137
+ assert scope in ["local", "global"], "scope must be either local or global"
138
+ assert (
139
+ not self.binary
140
+ ), "binarization is not supported when using binary mask already"
141
+ if scope == "local":
142
+ # Local binarization
143
+ for i, lamb in enumerate(self.lambs):
144
+ num_heads = lamb.size(0)
145
+ num_activate_heads = int(num_heads * ratio)
146
+ # Sort the lambda values with stable sorting to maintain order for equal values
147
+ sorted_lamb, sorted_indices = torch.sort(
148
+ lamb, descending=True, stable=True
149
+ )
150
+ # Find the threshold value
151
+ threshold = sorted_lamb[num_activate_heads - 1]
152
+ # Create a mask based on the sorted indices
153
+ mask = torch.zeros_like(lamb)
154
+ mask[sorted_indices[:num_activate_heads]] = 1.0
155
+ # Binarize the lambda based on the threshold and the mask
156
+ self.lambs[i] = torch.where(
157
+ lamb > threshold, torch.ones_like(lamb), mask
158
+ )
159
+ else:
160
+ # Global binarization
161
+ all_lambs = torch.cat([lamb.flatten() for lamb in self.lambs])
162
+ num_total = all_lambs.numel()
163
+ num_activate = int(num_total * ratio)
164
+ # Sort all lambda values globally with stable sorting
165
+ sorted_lambs, sorted_indices = torch.sort(
166
+ all_lambs, descending=True, stable=True
167
+ )
168
+ # Find the global threshold value
169
+ threshold = sorted_lambs[num_activate - 1]
170
+ # Create a global mask based on the sorted indices
171
+ global_mask = torch.zeros_like(all_lambs)
172
+ global_mask[sorted_indices[:num_activate]] = 1.0
173
+ # Binarize all lambdas based on the global threshold and mask
174
+ start_idx = 0
175
+ for i in range(len(self.lambs)):
176
+ end_idx = start_idx + self.lambs[i].numel()
177
+ lamb_mask = global_mask[start_idx:end_idx].reshape(self.lambs[i].shape)
178
+ self.lambs[i] = torch.where(
179
+ self.lambs[i] > threshold, torch.ones_like(self.lambs[i]), lamb_mask
180
+ )
181
+ start_idx = end_idx
182
+ self.binary = True
183
+
184
+ @staticmethod
185
+ def masking_fn(hidden_states, **kwargs):
186
+ hidden_states_dtype = hidden_states.dtype
187
+ lamb = kwargs["lamb"].view(1, 1, kwargs["lamb"].shape[0])
188
+ if kwargs.get("masking", None) == "sigmoid":
189
+ mask = torch.sigmoid(lamb)
190
+ elif kwargs.get("masking", None) == "binary":
191
+ mask = lamb
192
+ elif kwargs.get("masking", None) == "continues2binary":
193
+ # TODO: this might cause potential issue as it hard threshold at 0
194
+ mask = (lamb > 0).float()
195
+ elif kwargs.get("masking", None) == "no_masking":
196
+ mask = torch.ones_like(lamb)
197
+ else:
198
+ raise NotImplementedError
199
+ epsilon = kwargs.get("epsilon", 0.0)
200
+
201
+ if hidden_states.dim() == 2:
202
+ mask = mask.squeeze(1)
203
+
204
+ hidden_states = hidden_states * mask + torch.randn_like(
205
+ hidden_states
206
+ ) * epsilon * (1 - mask)
207
+ return hidden_states.to(hidden_states_dtype)
208
+
209
+ def get_norm_masking_hook(self, init_value=1.0):
210
+ """
211
+ Get a hook function to mask feed forward layer
212
+ """
213
+
214
+ def hook_fn(module, args, kwargs, output, name):
215
+ # initialize lambda with acual head dim in the first run
216
+ if self.lambs[self.hook_counter] is None:
217
+ self.lambs[self.hook_counter] = (
218
+ torch.ones(
219
+ self.module_neurons[name],
220
+ device=self.pipeline.device,
221
+ dtype=self.dtype,
222
+ )
223
+ * init_value
224
+ )
225
+ self.lambs[self.hook_counter].requires_grad = True
226
+ # load norm lambda module name for logging
227
+ self.lambs_module_names[self.hook_counter] = name
228
+
229
+ # perform masking
230
+ output = self.masking_fn(
231
+ output,
232
+ masking=self.masking,
233
+ lamb=self.lambs[self.hook_counter],
234
+ epsilon=self.epsilon,
235
+ eps=self.eps,
236
+ use_log=self.use_log,
237
+ )
238
+ self.hook_counter += 1
239
+ self.hook_counter %= len(self.lambs)
240
+ return output
241
+
242
+ return hook_fn
utils.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ from copy import deepcopy
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from diffusers.models.activations import GEGLU, GELU
8
+ from cross_attn_hook import CrossAttentionExtractionHook
9
+ from ffn_hooker import FeedForwardHooker
10
+ from norm_attn_hook import NormHooker
11
+
12
+
13
+ # create dummy module for skip connection
14
+ class SkipConnection(torch.nn.Module):
15
+ def __init__(self):
16
+ super(SkipConnection, self).__init__()
17
+
18
+ def forward(*args, **kwargs):
19
+ return args[1]
20
+
21
+
22
+ def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
23
+ total_num_lambs = 0
24
+ num_activate_lambs = 0
25
+ binary = getattr(
26
+ hooker, "binary", None
27
+ ) # if binary is not present, it will return None for ff_hooks
28
+ for lamb in hooker.lambs:
29
+ total_num_lambs += lamb.size(0)
30
+ if binary:
31
+ assert threshold is None, "threshold should be None for binary mask"
32
+ num_activate_lambs += lamb.sum().item()
33
+ else:
34
+ assert (
35
+ threshold is not None
36
+ ), "threshold must be provided for non-binary mask"
37
+ num_activate_lambs += (lamb >= threshold).sum().item()
38
+ return total_num_lambs, num_activate_lambs, num_activate_lambs / total_num_lambs
39
+
40
+
41
+ def create_pipeline(
42
+ pipe,
43
+ model_id,
44
+ device,
45
+ torch_dtype,
46
+ save_pt=None,
47
+ lambda_threshold: float = 1,
48
+ binary=True,
49
+ epsilon=0.0,
50
+ masking="binary",
51
+ attn_name="attn",
52
+ return_hooker=False,
53
+ scope=None,
54
+ ratio=None,
55
+ ):
56
+ """
57
+ create the pipeline and optionally load the saved mask
58
+ """
59
+ pipe.to(device)
60
+ pipe.vae.requires_grad_(False)
61
+ if hasattr(pipe, "unet"):
62
+ pipe.unet.requires_grad_(False)
63
+ else:
64
+ pipe.transformer.requires_grad_(False)
65
+ if save_pt:
66
+ # TODO should merge all the hooks checkpoint into one
67
+ if "ff.pt" in save_pt or "attn.pt" in save_pt:
68
+ save_pts = get_save_pts(save_pt)
69
+
70
+ cross_attn_hooker = CrossAttentionExtractionHook(
71
+ pipe,
72
+ model_name=model_id,
73
+ regex=".*",
74
+ dtype=torch_dtype,
75
+ head_num_filter=1,
76
+ masking=masking, # need to change to binary during inference
77
+ dst=save_pts["attn"],
78
+ epsilon=epsilon,
79
+ attn_name=attn_name,
80
+ binary=binary,
81
+ )
82
+ cross_attn_hooker.add_hooks(init_value=1)
83
+
84
+ ff_hooker = FeedForwardHooker(
85
+ pipe,
86
+ regex=".*",
87
+ dtype=torch_dtype,
88
+ masking=masking,
89
+ dst=save_pts["ff"],
90
+ epsilon=epsilon,
91
+ binary=binary,
92
+ )
93
+ ff_hooker.add_hooks(init_value=1)
94
+
95
+ if os.path.exists(save_pts["norm"]):
96
+ norm_hooker = NormHooker(
97
+ pipe,
98
+ regex=".*",
99
+ dtype=torch_dtype,
100
+ masking=masking,
101
+ dst=save_pts["norm"],
102
+ epsilon=epsilon,
103
+ binary=binary,
104
+ )
105
+ norm_hooker.add_hooks(init_value=1)
106
+ else:
107
+ norm_hooker = None
108
+
109
+ _ = pipe("abc", num_inference_steps=1)
110
+ cross_attn_hooker.load(device=device, threshold=lambda_threshold)
111
+ ff_hooker.load(device=device, threshold=lambda_threshold)
112
+ if norm_hooker:
113
+ norm_hooker.load(device=device, threshold=lambda_threshold)
114
+ if scope == "local" or scope == "global":
115
+ if isinstance(ratio, float):
116
+ attn_hooker_ratio = ratio
117
+ ff_hooker_ratio = ratio
118
+ else:
119
+ attn_hooker_ratio, ff_hooker_ratio = ratio[0], ratio[1]
120
+
121
+ if norm_hooker:
122
+ if len(ratio) < 3:
123
+ raise ValueError("Need to provide ratio for norm layer")
124
+ norm_hooker_ratio = ratio[2]
125
+
126
+ cross_attn_hooker.binarize(scope, attn_hooker_ratio)
127
+ ff_hooker.binarize(scope, ff_hooker_ratio)
128
+ if norm_hooker:
129
+ norm_hooker.binarize(scope, norm_hooker_ratio)
130
+ hookers = [cross_attn_hooker, ff_hooker]
131
+ if norm_hooker:
132
+ hookers.append(norm_hooker)
133
+
134
+ if return_hooker:
135
+ return pipe, hookers
136
+ else:
137
+ return pipe
138
+
139
+
140
+ def linear_layer_pruning(module, lamb):
141
+ heads_to_keep = torch.nonzero(lamb).squeeze()
142
+ if len(heads_to_keep.shape) == 0:
143
+ # if only one head is kept, or none
144
+ heads_to_keep = heads_to_keep.unsqueeze(0)
145
+
146
+ modules_to_remove = [module.to_k, module.to_q, module.to_v]
147
+ new_heads = int(lamb.sum().item())
148
+
149
+ if new_heads == 0:
150
+ return SkipConnection()
151
+
152
+ for module_to_remove in modules_to_remove:
153
+ # get head dimension
154
+ inner_dim = module_to_remove.out_features // module.heads
155
+ # place holder for the rows to keep
156
+ rows_to_keep = torch.zeros(
157
+ module_to_remove.out_features,
158
+ dtype=torch.bool,
159
+ device=module_to_remove.weight.device,
160
+ )
161
+
162
+ for idx in heads_to_keep:
163
+ rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True
164
+
165
+ # overwrite the inner projection with masked projection
166
+ module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
167
+ if module_to_remove.bias is not None:
168
+ module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
169
+ module_to_remove.out_features = int(sum(rows_to_keep).item())
170
+
171
+ # Also update the output projection layer if available, (for FLUXSingleAttnProcessor2_0)
172
+ # with column masking, dim 1
173
+ if getattr(module, "to_out", None) is not None:
174
+ module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
175
+ module.to_out[0].in_features = int(sum(rows_to_keep).item())
176
+
177
+ # update parameters in the attention module
178
+ module.inner_dim = module.inner_dim // module.heads * new_heads
179
+ try:
180
+ module.query_dim = module.query_dim // module.heads * new_heads
181
+ module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
182
+ except:
183
+ pass
184
+ module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
185
+ module.heads = new_heads
186
+ return module
187
+
188
+
189
+ def ffn_linear_layer_pruning(module, lamb):
190
+ lambda_to_keep = torch.nonzero(lamb).squeeze()
191
+ if len(lambda_to_keep) == 0:
192
+ return SkipConnection()
193
+
194
+ num_lambda = len(lambda_to_keep)
195
+
196
+ if isinstance(module.net[0], GELU):
197
+ # linear layer weight remove before activation
198
+ module.net[0].proj.weight.data = module.net[0].proj.weight.data[
199
+ lambda_to_keep, :
200
+ ]
201
+ module.net[0].proj.out_features = num_lambda
202
+ if module.net[0].proj.bias is not None:
203
+ module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]
204
+
205
+ update_act = GELU(module.net[0].proj.in_features, num_lambda)
206
+ update_act.proj = module.net[0].proj
207
+ module.net[0] = update_act
208
+ elif isinstance(module.net[0], GEGLU):
209
+ output_feature = module.net[0].proj.out_features
210
+ module.net[0].proj.weight.data = torch.cat(
211
+ [
212
+ module.net[0].proj.weight.data[: output_feature // 2, :][
213
+ lambda_to_keep, :
214
+ ],
215
+ module.net[0].proj.weight.data[output_feature // 2 :][
216
+ lambda_to_keep, :
217
+ ],
218
+ ],
219
+ dim=0,
220
+ )
221
+ module.net[0].proj.out_features = num_lambda * 2
222
+ if module.net[0].proj.bias is not None:
223
+ module.net[0].proj.bias.data = torch.cat(
224
+ [
225
+ module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
226
+ module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
227
+ ]
228
+ )
229
+
230
+ update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
231
+ update_act.proj = module.net[0].proj
232
+ module.net[0] = update_act
233
+
234
+ # proj weight after activation
235
+ module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
236
+ module.net[2].in_features = num_lambda
237
+
238
+ return module
239
+
240
+
241
+ # create SparsityLinear module
242
+ class SparsityLinear(torch.nn.Module):
243
+ def __init__(self, in_features, out_features, lambda_to_keep, num_lambda):
244
+ super(SparsityLinear, self).__init__()
245
+ self.linear = torch.nn.Linear(in_features, num_lambda)
246
+ self.out_features = out_features
247
+ self.lambda_to_keep = lambda_to_keep
248
+
249
+ def forward(self, x):
250
+ x = self.linear(x)
251
+ output = torch.zeros(
252
+ x.size(0), self.out_features, device=x.device, dtype=x.dtype
253
+ )
254
+ output[:, self.lambda_to_keep] = x
255
+ return output
256
+
257
+
258
+ def norm_layer_pruning(module, lamb):
259
+ """
260
+ Pruning the layer normalization layer for FLUX model
261
+ """
262
+ lambda_to_keep = torch.nonzero(lamb).squeeze()
263
+ if len(lambda_to_keep) == 0:
264
+ return SkipConnection()
265
+
266
+ num_lambda = len(lambda_to_keep)
267
+
268
+ # get num_features
269
+ in_features = module.linear.in_features
270
+ out_features = module.linear.out_features
271
+
272
+ linear = SparsityLinear(in_features, out_features, lambda_to_keep, num_lambda)
273
+ linear.linear.weight.data = module.linear.weight.data[lambda_to_keep]
274
+ linear.linear.bias.data = module.linear.bias.data[lambda_to_keep]
275
+ module.linear = linear
276
+ return module
277
+
278
+
279
+ def get_save_pts(save_pt):
280
+ if "ff.pt" in save_pt:
281
+ ff_save_pt = deepcopy(save_pt) # avoid in-place operation
282
+ attn_save_pt = save_pt.split(os.sep)
283
+ attn_save_pt[-1] = attn_save_pt[-1].replace("ff", "attn")
284
+ attn_save_pt_output = os.sep.join(attn_save_pt)
285
+ attn_save_pt[-1] = attn_save_pt[-1].replace("attn", "norm")
286
+ norm_save_pt = os.sep.join(attn_save_pt)
287
+
288
+ return {
289
+ "ff": ff_save_pt,
290
+ "attn": attn_save_pt_output,
291
+ "norm": norm_save_pt,
292
+ }
293
+ else:
294
+ attn_save_pt = deepcopy(save_pt)
295
+ ff_save_pt = save_pt.split(os.sep)
296
+ ff_save_pt[-1] = ff_save_pt[-1].replace("attn", "ff")
297
+ ff_save_pt_output = os.sep.join(ff_save_pt)
298
+ ff_save_pt[-1] = ff_save_pt[-1].replace("ff", "norm")
299
+ norm_save_pt = os.sep.join(attn_save_pt)
300
+
301
+ return {
302
+ "ff": ff_save_pt_output,
303
+ "attn": attn_save_pt,
304
+ "norm": norm_save_pt,
305
+ }
306
+
307
+
308
+ def save_img(pipe, g_cpu, steps, prompt, save_path):
309
+ image = pipe(prompt, generator=g_cpu, num_inference_steps=steps)
310
+ image["images"][0].save(save_path)