zhangyang-0123 committed on
Commit 82d824b
1 Parent(s): 3964559

add ecodiff demo

app.py CHANGED
@@ -1,154 +1,224 @@
  import gradio as gr
- import numpy as np
- import random

- import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
  import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
-
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
-
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
-
-
- @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator().manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
-
-     return image, seed
-
-
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")

-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )

-             run_button = gr.Button("Run", scale=0, variant="primary")

-         result = gr.Image(label="Result", show_label=False)

-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )

-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
              )

-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024, # Replace with defaults that work for your model
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0, # Replace with defaults that work for your model
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2, # Replace with defaults that work for your model
-                 )
-
-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
-     )

  if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
+ from dataclasses import dataclass

  import torch
+ from tqdm import tqdm
+
+ from src.utils import (
+     create_pipeline,
+     calculate_mask_sparsity,
+     ffn_linear_layer_pruning,
+     linear_layer_pruning,
+ )
+ from diffusers import StableDiffusionXLPipeline
+
+
+ def get_model_param_summary(model, verbose=False):
+     params_dict = dict()
+     overall_params = 0
+     for name, params in model.named_parameters():
+         num_params = params.numel()
+         overall_params += num_params
+         if verbose:
+             print(f"GPU Memory Requirement for {name}: {params} MiB")
+         params_dict.update({name: num_params})
+     params_dict.update({"overall": overall_params})
+     return params_dict
+
+
+ @dataclass
+ class GradioArgs:
+     ckpt: str = "./mask/ff.pt"
+     device: str = "cuda:0"
+     seed: list = None
+     prompt: str = None
+     mix_precision: str = "bf16"
+     num_intervention_steps: int = 50
+     model: str = "sdxl"
+     binary: bool = False
+     masking: str = "binary"
+     scope: str = "global"
+     ratio: list = None
+     width: int = None
+     height: int = None
+     epsilon: float = 0.0
+     lambda_threshold: float = 0.001
+
+     def __post_init__(self):
+         if self.seed is None:
+             self.seed = [44]
+         if self.ratio is None:
+             self.ratio = [0.68, 0.88]
+
+
+ def prune_model(pipe, hookers):
+     # remove parameters in attention blocks
+     cross_attn_hooker = hookers[0]
+     for name in tqdm(cross_attn_hooker.hook_dict.keys(), desc="Pruning attention layers"):
+         if getattr(pipe, "unet", None):
+             module = pipe.unet.get_submodule(name)
+         else:
+             module = pipe.transformer.get_submodule(name)
+         lamb = cross_attn_hooker.lambs[cross_attn_hooker.lambs_module_names.index(name)]
+         assert module.heads == lamb.shape[0]
+         module = linear_layer_pruning(module, lamb)
+
+         parent_module_name, child_name = name.rsplit(".", 1)
+         if getattr(pipe, "unet", None):
+             parent_module = pipe.unet.get_submodule(parent_module_name)
+         else:
+             parent_module = pipe.transformer.get_submodule(parent_module_name)
+         setattr(parent_module, child_name, module)
+
+     # remove parameters in ffn blocks
+     ffn_hook = hookers[1]
+     for name in tqdm(ffn_hook.hook_dict.keys(), desc="Pruning on FFN linear layer"):
+         if getattr(pipe, "unet", None):
+             module = pipe.unet.get_submodule(name)
+         else:
+             module = pipe.transformer.get_submodule(name)
+         lamb = ffn_hook.lambs[ffn_hook.lambs_module_names.index(name)]
+         module = ffn_linear_layer_pruning(module, lamb)
+
+         parent_module_name, child_name = name.rsplit(".", 1)
+         if getattr(pipe, "unet", None):
+             parent_module = pipe.unet.get_submodule(parent_module_name)
+         else:
+             parent_module = pipe.transformer.get_submodule(parent_module_name)
+         setattr(parent_module, child_name, module)
+
+     cross_attn_hooker.clear_hooks()
+     ffn_hook.clear_hooks()
+     return pipe
+
+
+ def binary_mask_eval(args):
+     # load sdxl model
+     pipe = StableDiffusionXLPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+     ).to(args.device)
+
+     device = args.device
+     torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
+     mask_pipe, hookers = create_pipeline(
+         pipe,
+         args.model,
+         device,
+         torch_dtype,
+         args.ckpt,
+         binary=args.binary,
+         lambda_threshold=args.lambda_threshold,
+         epsilon=args.epsilon,
+         masking=args.masking,
+         return_hooker=True,
+         scope=args.scope,
+         ratio=args.ratio,
+     )

+     # Print mask sparsity info
+     threshold = None if args.binary else args.lambda_threshold
+     threshold = None if args.scope is not None else threshold
+     name = ["ff", "attn"]
+     for n, hooker in zip(name, hookers):
+         total_num_heads, num_activate_heads, mask_sparsity = calculate_mask_sparsity(hooker, threshold)
+         print(f"model: {args.model}, {n} masking: {args.masking}")
+         print(
+             f"total num heads: {total_num_heads},"
+             + f"num activate heads: {num_activate_heads}, mask sparsity: {mask_sparsity}"
+         )

+     # Prune the model
+     pruned_pipe = prune_model(mask_pipe, hookers)

+     # reload the original model
+     pipe = StableDiffusionXLPipeline.from_pretrained(
+         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+     ).to(args.device)

+     # get model param summary
+     print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
+     print(f"pruned model param: {get_model_param_summary(pruned_pipe.unet)['overall']}")
+     print("prune complete")
+     return pipe, pruned_pipe
+
+
+ def generate_images(prompt, seed, steps, pipe, pruned_pipe):
+     # Run the model and return images directly
+     g_cpu = torch.Generator("cuda:0").manual_seed(seed)
+     original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
+     g_cpu = torch.Generator("cuda:0").manual_seed(seed)
+     ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
+     return original_image, ecodiff_image
+
+
+ def on_prune_click(prompt, seed, steps):
+     args = GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
+     pipe, pruned_pipe = binary_mask_eval(args)
+     return pipe, pruned_pipe, [("Model Initialized", "green")]


+ def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
+     original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
+     return original_image, ecodiff_image
+
+
+ def create_demo():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Text-to-Image Generation with EcoDiff Pruned Model")
+         with gr.Row():
+             gr.Markdown(
+                 """
+                 # 🚧 Under Construction 🚧
+                 This demo is currently being developed and may not be fully functional. More models and pruning ratios will be supported soon.
+                 The current pruned model checkpoint is not optimal and does not provide the best performance.
+
+                 **Note: Please first initialize the model before generating images.**
+                 """
              )
+         with gr.Row():
+             model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
+             pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
+             prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
+             status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
+         with gr.Row():
+             prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
+             seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
+             steps = gr.Slider(label="Number of Steps", minimum=1, maximum=100, value=50, step=1, scale=1)
+             generate_btn = gr.Button("Generate Images")
+         gr.Examples(
+             examples=[
+                 "A clock tower floating in a sea of clouds",
+                 "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+                 "An astronaut riding a green horse",
+                 "A delicious ceviche cheesecake slice",
+             ],
+             inputs=[prompt],
+         )
+         with gr.Row():
+             original_output = gr.Image(label="Original Output")
+             ecodiff_output = gr.Image(label="EcoDiff Output")
+
+         pipe_state = gr.State(None)
+         pruned_pipe_state = gr.State(None)
+         prompt.submit(
+             fn=on_generate_click,
+             inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
+             outputs=[original_output, ecodiff_output],
+         )
+         prune_btn.click(
+             fn=on_prune_click,
+             inputs=[prompt, seed, steps],
+             outputs=[pipe_state, pruned_pipe_state, status_label],
+         )
+         generate_btn.click(
+             fn=on_generate_click,
+             inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
+             outputs=[original_output, ecodiff_output],
+         )
+
+     return demo

  if __name__ == "__main__":
+     demo = create_demo()
+     demo.launch(share=True)
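
The Gradio callbacks above simply wire the UI to on_prune_click (load SDXL, apply the masks, prune one copy) and on_generate_click (side-by-side generation with the original and pruned pipelines). The same flow can be exercised without the UI; the snippet below is a hypothetical usage sketch, not part of the commit, and assumes a CUDA device plus the mask checkpoints committed under ./mask.

# Hypothetical headless run of the demo's prune-then-generate flow.
# Assumes a CUDA GPU and the ./mask checkpoints from this commit.
from app import GradioArgs, binary_mask_eval, generate_images

args = GradioArgs(prompt="A clock tower floating in a sea of clouds", seed=[44])
pipe, pruned_pipe = binary_mask_eval(args)  # load SDXL twice, mask, then prune one copy
original, ecodiff = generate_images(args.prompt, 44, 50, pipe, pruned_pipe)
original.save("original.png")
ecodiff.save("ecodiff.png")
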
mask/attn.pt ADDED
Binary file (40.5 kB).

mask/ff.pt ADDED
Binary file (686 kB).
 
requirements.txt CHANGED
@@ -1,6 +1,4 @@
- accelerate
- diffusers
- invisible_watermark
- torch
- transformers
- xformers

+ diffusers==0.31.0
+ torch==2.4.1
+ transformers==4.45.2
+ accelerate==0.33.0
 
 
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (198 Bytes).

src/__pycache__/cross_attn_hook.cpython-310.pyc ADDED
Binary file (12.8 kB).

src/__pycache__/ffn_hooker.cpython-310.pyc ADDED
Binary file (6.4 kB).

src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.53 kB).
 
src/cross_attn_hook.py ADDED
@@ -0,0 +1,425 @@
+ import logging
+ import os
+ from collections import OrderedDict
+ from functools import partial
+
+ import torch
+
+ import re
+
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from diffusers.models.attention_processor import Attention
+ from diffusers.utils import deprecate
+
+
+ def scaled_dot_product_attention_atten_weight_only(
+     query, key, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+ ) -> torch.Tensor:
+     L, S = query.size(-2), key.size(-2)
+     scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+     if is_causal:
+         assert attn_mask is None
+         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+         attn_bias.to(query.dtype)
+
+     if attn_mask is not None:
+         if attn_mask.dtype == torch.bool:
+             attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+         else:
+             attn_bias += attn_mask
+     attn_weight = query @ key.transpose(-2, -1) * scale_factor
+     attn_weight += attn_bias
+     attn_weight = torch.softmax(attn_weight, dim=-1)
+     attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+     return attn_weight
+
+
+ def apply_rope(xq, xk, freqs_cis):
+     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+ def masking_fn(hidden_states, kwargs):
+     lamb = kwargs["lamb"].view(1, kwargs["lamb"].shape[0], 1, 1)
+     if kwargs.get("masking", None) == "sigmoid":
+         mask = torch.sigmoid(lamb)
+     elif kwargs.get("masking", None) == "binary":
+         mask = lamb
+     elif kwargs.get("masking", None) == "continues2binary":
+         # TODO: this might cause potential issue as it hard threshold at 0
+         mask = (lamb > 0).float()
+     elif kwargs.get("masking", None) == "no_masking":
+         mask = torch.ones_like(lamb)
+     else:
+         raise NotImplementedError
+     epsilon = kwargs.get("epsilon", 0.0)
+     hidden_states = hidden_states * mask + torch.randn_like(hidden_states) * epsilon * (1 - mask)
+     return hidden_states
+
+
+ class AttnProcessor2_0_Masking:
+     r"""
+     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+     """
+
+     def __init__(self):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         temb: Optional[torch.Tensor] = None,
+         *args,
+         **kwargs,
+     ):
+         if len(args) > 0 or kwargs.get("scale", None) is not None:
+             deprecation_message = (
+                 "The `scale` argument is deprecated and will be ignored. "
+                 "Please remove it, as passing it will raise an error "
+                 "in the future. `scale` should directly be passed while "
+                 "calling the underlying pipeline component i.e., via "
+                 "`cross_attention_kwargs`."
+             )
+             deprecate("scale", "1.0.0", deprecation_message)
+
+         residual = hidden_states
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             # scaled_dot_product_attention expects attention_mask shape to be
+             # (batch, heads, source_length, target_length)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         if getattr(attn, "norm_q", None) is not None:
+             query = attn.norm_q(query)
+
+         if getattr(attn, "norm_k", None) is not None:
+             key = attn.norm_k(key)
+
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         if kwargs.get("return_attention", True):
+             # add the attention output from F.scaled_dot_product_attention
+             attn_weight = scaled_dot_product_attention_atten_weight_only(
+                 query, key, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+             )
+             hidden_states_aft_attention_ops = hidden_states.clone()
+             attn_weight_old = attn_weight.to(hidden_states.device).clone()
+         else:
+             hidden_states_aft_attention_ops = None
+             attn_weight_old = None
+
+         # masking for the hidden_states after the attention ops
+         if kwargs.get("lamb", None) is not None:
+             hidden_states = masking_fn(hidden_states, kwargs)
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states, hidden_states_aft_attention_ops, attn_weight_old
+
+ class BaseCrossAttentionHooker:
+     def __init__(self, pipeline, regex, dtype, head_num_filter, masking, model_name, attn_name, use_log, eps):
+         self.pipeline = pipeline
+         # unet for SD2 SDXL, transformer for SD3, FLUX DIT
+         self.net = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
+         self.model_name = model_name
+         self.module_heads = OrderedDict()
+         self.masking = masking
+         self.hook_dict = {}
+         self.regex = regex
+         self.dtype = dtype
+         self.head_num_filter = head_num_filter
+         self.attn_name = attn_name
+         self.logger = logging.getLogger(__name__)
+         self.use_log = use_log  # use log parameter to control hard_discrete
+         self.eps = eps
+
+     def add_hooks_to_cross_attention(self, hook_fn: callable):
+         """
+         Add forward hooks to every cross attention
+         :param hook_fn: a callable to be added to torch nn module as a hook
+         :return:
+         """
+         total_hooks = 0
+         for name, module in self.net.named_modules():
+             name_last_word = name.split(".")[-1]
+             if self.attn_name in name_last_word:
+                 if re.match(self.regex, name):
+                     hook_fn = partial(hook_fn, name=name)
+                     hook = module.register_forward_hook(hook_fn, with_kwargs=True)
+                     self.hook_dict[name] = hook
+                     self.module_heads[name] = module.heads
+                     self.logger.info(f"Adding hook to {name}, module.heads: {module.heads}")
+                     total_hooks += 1
+         self.logger.info(f"Total hooks added: {total_hooks}")
+
+     def clear_hooks(self):
+         """clear all hooks"""
+         for hook in self.hook_dict.values():
+             hook.remove()
+         self.hook_dict.clear()
+
+
+ class CrossAttentionExtractionHook(BaseCrossAttentionHooker):
+     def __init__(
+         self,
+         pipeline,
+         dtype,
+         head_num_filter,
+         masking,
+         dst,
+         regex=None,
+         epsilon=0.0,
+         binary=False,
+         return_attention=False,
+         model_name="sdxl",
+         attn_name="attn",
+         use_log=False,
+         eps=1e-6,
+     ):
+         super().__init__(
+             pipeline,
+             regex,
+             dtype,
+             head_num_filter,
+             masking=masking,
+             model_name=model_name,
+             attn_name=attn_name,
+             use_log=use_log,
+             eps=eps,
+         )
+         self.attention_processor = AttnProcessor2_0_Masking()
+         self.lambs = []
+         self.lambs_module_names = []
+         self.cross_attn = []
+         self.hook_counter = 0
+         self.device = self.pipeline.unet.device if hasattr(self.pipeline, "unet") else self.pipeline.transformer.device
+         self.dst = dst
+         self.epsilon = epsilon
+         self.binary = binary
+         self.return_attention = return_attention
+         self.model_name = model_name
+
+     def clean_cross_attn(self):
+         self.cross_attn = []
+
+     def validate_dst(self):
+         if os.path.exists(self.dst):
+             raise ValueError(f"Destination {self.dst} already exists")
+
+     def save(self, name: str = None):
+         if name is not None:
+             dst = os.path.join(os.path.dirname(self.dst), name)
+         else:
+             dst = self.dst
+         dst_dir = os.path.dirname(dst)
+         if not os.path.exists(dst_dir):
+             self.logger.info(f"Creating directory {dst_dir}")
+             os.makedirs(dst_dir)
+         torch.save(self.lambs, dst)
+
+     @property
+     def get_lambda_block_names(self):
+         return self.lambs_module_names
+
+     def load(self, device, threshold=2.5):
+         if os.path.exists(self.dst):
+             self.logger.info(f"loading lambda from {self.dst}")
+             self.lambs = torch.load(self.dst, weights_only=True, map_location=device)
+             if self.binary:
+                 # set binary masking for each lambda by using clamp
+                 self.lambs = [(torch.relu(lamb - threshold) > 0).float() for lamb in self.lambs]
+         else:
+             self.logger.info("skipping loading, training from scratch")
+
+     def binarize(self, scope: str, ratio: float):
+         assert scope in ["local", "global"], "scope must be either local or global"
+         assert not self.binary, "binarization is not supported when using binary mask already"
+         if scope == "local":
+             # Local binarization
+             for i, lamb in enumerate(self.lambs):
+                 num_heads = lamb.size(0)
+                 num_activate_heads = int(num_heads * ratio)
+                 # Sort the lambda values with stable sorting to maintain order for equal values
+                 sorted_lamb, sorted_indices = torch.sort(lamb, descending=True, stable=True)
+                 # Find the threshold value
+                 threshold = sorted_lamb[num_activate_heads - 1]
+                 # Create a mask based on the sorted indices
+                 mask = torch.zeros_like(lamb)
+                 mask[sorted_indices[:num_activate_heads]] = 1.0
+                 # Binarize the lambda based on the threshold and the mask
+                 self.lambs[i] = torch.where(lamb > threshold, torch.ones_like(lamb), mask)
+         else:
+             # Global binarization
+             all_lambs = torch.cat([lamb.flatten() for lamb in self.lambs])
+             num_total = all_lambs.numel()
+             num_activate = int(num_total * ratio)
+             # Sort all lambda values globally with stable sorting
+             sorted_lambs, sorted_indices = torch.sort(all_lambs, descending=True, stable=True)
+             # Find the global threshold value
+             threshold = sorted_lambs[num_activate - 1]
+             # Create a global mask based on the sorted indices
+             global_mask = torch.zeros_like(all_lambs)
+             global_mask[sorted_indices[:num_activate]] = 1.0
+             # Binarize all lambdas based on the global threshold and mask
+             start_idx = 0
+             for i in range(len(self.lambs)):
+                 end_idx = start_idx + self.lambs[i].numel()
+                 lamb_mask = global_mask[start_idx:end_idx].reshape(self.lambs[i].shape)
+                 self.lambs[i] = torch.where(self.lambs[i] > threshold, torch.ones_like(self.lambs[i]), lamb_mask)
+                 start_idx = end_idx
+         self.binary = True
+
+     def bizarize_threshold(self, threshold: float):
+         """
+         Binarize lambda values based on a predefined threshold.
+         :param threshold: The threshold value for binarization
+         """
+         assert not self.binary, "Binarization is not supported when using binary mask already"
+
+         for i in range(len(self.lambs)):
+             self.lambs[i] = (self.lambs[i] >= threshold).float()
+
+         self.binary = True
+
+     def get_cross_attn_extraction_hook(self, init_value=1.0):
+         """get a hook function to extract cross attention"""
+
+         # the reason to use a function inside a function is to save the extracted cross attention
+         def hook_fn(module, args, kwargs, output, name):
+             # initialize lambda with actual head dim in the first run
+             if self.lambs[self.hook_counter] is None:
+                 self.lambs[self.hook_counter] = (
+                     torch.ones(module.heads, device=self.pipeline.device, dtype=self.dtype) * init_value
+                 )
+                 # Only set requires_grad to True when the head number is larger than the filter
+                 if self.head_num_filter <= module.heads:
+                     self.lambs[self.hook_counter].requires_grad = True
+
+             # load attn lambda module name for logging
+             self.lambs_module_names[self.hook_counter] = name
+
+             hidden_states, _, attention_output = self.attention_processor(
+                 module,
+                 args[0],
+                 encoder_hidden_states=kwargs["encoder_hidden_states"],
+                 attention_mask=kwargs["attention_mask"],
+                 lamb=self.lambs[self.hook_counter],
+                 masking=self.masking,
+                 epsilon=self.epsilon,
+                 return_attention=self.return_attention,
+                 use_log=self.use_log,
+                 eps=self.eps,
+             )
+             if attention_output is not None:
+                 self.cross_attn.append(attention_output)
+             self.hook_counter += 1
+             self.hook_counter %= len(self.lambs)
+             return hidden_states
+
+         return hook_fn
+
+     def add_hooks(self, init_value=1.0):
+         hook_fn = self.get_cross_attn_extraction_hook(init_value)
+         self.add_hooks_to_cross_attention(hook_fn)
+         # initialize the lambda
+         self.lambs = [None] * len(self.module_heads)
+         # initialize the lambda module names
+         self.lambs_module_names = [None] * len(self.module_heads)
+
+     def get_process_cross_attn_result(self, text_seq_length, timestep: int = -1):
+         if isinstance(timestep, str):
+             timestep = int(timestep)
+         # num_lambda_block contains lambda (head masking)
+         num_lambda_block = len(self.lambs)
+
+         # get the start and end position of the timestep
+         start_pos = timestep * num_lambda_block
+         end_pos = (timestep + 1) * num_lambda_block
+         if end_pos > len(self.cross_attn):
+             raise ValueError(f"timestep {timestep} is out of range")
+
+         # list[cross_attn_map] num_layer x [batch, num_heads, seq_vis_tokens, seq_text_tokens]
+         attn_maps = self.cross_attn[start_pos:end_pos]
+
+         def heatmap(attn_list, attn_idx, head_idx, text_idx):
+             # only select second element in the tuple (with text guided attention)
+             # layer_idx, 1, head_idx, seq_vis_tokens, seq_text_tokens
+             map = attn_list[attn_idx][1][head_idx][:][:, text_idx]
+             # get the size of the heatmap
+             size = int(map.shape[0] ** 0.5)
+             map = map.view(size, size, 1)
+             data = map.cpu().float().numpy()
+             return data
+
+         output_dict = {}
+         for lambda_block_idx, lambda_block_name in zip(range(num_lambda_block), self.lambs_module_names):
+             data_list = []
+             for head_idx in range(len(self.lambs[lambda_block_idx])):
+                 for token_idx in range(text_seq_length):
+                     # number of heatmap is equal to the number of tokens in the text sequence X number of heads
+                     data_list.append(heatmap(attn_maps, lambda_block_idx, head_idx, token_idx))
+             output_dict[lambda_block_name] = {"attn_map": data_list, "lambda": self.lambs[lambda_block_idx]}
+         return output_dict
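
In AttnProcessor2_0_Masking, the mask is applied right after F.scaled_dot_product_attention, while the hidden states are still shaped (batch, heads, seq_len, head_dim): each entry of lamb scales one whole head, and epsilon optionally replaces dropped heads with noise. A toy check of that behavior (hypothetical, not part of the commit):

# Toy check of masking_fn: with masking="binary" and epsilon=0.0,
# a zero entry in lamb zeroes out the corresponding attention head.
import torch
from src.cross_attn_hook import masking_fn

hidden = torch.randn(1, 4, 16, 8)          # (batch, heads, seq_len, head_dim)
lamb = torch.tensor([1.0, 0.0, 1.0, 0.0])  # keep heads 0 and 2, drop heads 1 and 3
masked = masking_fn(hidden, {"lamb": lamb, "masking": "binary", "epsilon": 0.0})
assert torch.equal(masked[:, 0], hidden[:, 0])                    # kept head is untouched
assert torch.equal(masked[:, 1], torch.zeros_like(hidden[:, 1]))  # dropped head is zeroed
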
src/ffn_hooker.py ADDED
@@ -0,0 +1,193 @@
+ import logging
+ import os
+ from collections import OrderedDict
+ from functools import partial
+
+ import diffusers
+ import torch
+ from torch import nn
+
+ import re
+
+
+ class FeedForwardHooker:
+     def __init__(
+         self,
+         pipeline: nn.Module,
+         regex: str,
+         dtype: torch.dtype,
+         masking: str,
+         dst: str,
+         epsilon: float = 0.0,
+         eps: float = 1e-6,
+         use_log: bool = False,
+         binary: bool = False,
+     ):
+         self.pipeline = pipeline
+         self.net = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
+         self.logger = logging.getLogger(__name__)
+         self.dtype = dtype
+         self.regex = regex
+         self.hook_dict = {}
+         self.masking = masking
+         self.dst = dst
+         self.epsilon = epsilon
+         self.eps = eps
+         self.use_log = use_log
+         self.lambs = []
+         self.lambs_module_names = []  # store the module names for each lambda block
+         self.hook_counter = 0
+         self.module_neurons = OrderedDict()
+         self.binary = binary  # default, need to discuss if we need to keep this attribute or not
+
+     def add_hooks_to_ff(self, hook_fn: callable):
+         total_hooks = 0
+         for name, module in self.net.named_modules():
+             name_last_word = name.split(".")[-1]
+             if "ff" in name_last_word:
+                 if re.match(self.regex, name):
+                     hook_fn_with_name = partial(hook_fn, name=name)
+                     actual_module = module.net[0]
+                     hook = actual_module.register_forward_hook(hook_fn_with_name, with_kwargs=True)
+                     self.hook_dict[name] = hook
+
+                     if isinstance(actual_module, diffusers.models.activations.GEGLU):  # geglu
+                         # due to the GEGLU chunking, we need to divide by 2
+                         self.module_neurons[name] = actual_module.proj.out_features // 2
+                     elif isinstance(actual_module, diffusers.models.activations.GELU):  # gelu
+                         self.module_neurons[name] = actual_module.proj.out_features
+                     else:
+                         raise NotImplementedError(f"Module {name} is not implemented, please check")
+                     self.logger.info(f"Adding hook to {name}, neurons: {self.module_neurons[name]}")
+                     total_hooks += 1
+         self.logger.info(f"Total hooks added: {total_hooks}")
+         return self.hook_dict
+
+     def add_hooks(self, init_value=1.0):
+         hook_fn = self.get_ff_masking_hook(init_value)
+         self.add_hooks_to_ff(hook_fn)
+         # initialize the lambda
+         self.lambs = [None] * len(self.hook_dict)
+         # initialize the lambda module names
+         self.lambs_module_names = [None] * len(self.hook_dict)
+
+     def clear_hooks(self):
+         """clear all hooks"""
+         for hook in self.hook_dict.values():
+             hook.remove()
+         self.hook_dict.clear()
+
+     def save(self, name: str = None):
+         if name is not None:
+             dst = os.path.join(os.path.dirname(self.dst), name)
+         else:
+             dst = self.dst
+         dst_dir = os.path.dirname(dst)
+         if not os.path.exists(dst_dir):
+             self.logger.info(f"Creating directory {dst_dir}")
+             os.makedirs(dst_dir)
+         torch.save(self.lambs, dst)
+
+     @property
+     def get_lambda_block_names(self):
+         return self.lambs_module_names
+
+     def load(self, device, threshold=2.5):
+         if os.path.exists(self.dst):
+             self.logger.info(f"loading lambda from {self.dst}")
+             self.lambs = torch.load(self.dst, weights_only=True, map_location=device)
+             if self.binary:
+                 # set binary masking for each lambda by using clamp
+                 self.lambs = [(torch.relu(lamb - threshold) > 0).float() for lamb in self.lambs]
+             else:
+                 self.lambs = [torch.clamp(lamb, min=0.0) for lamb in self.lambs]
+                 # self.lambs_module_names = [None for _ in self.lambs]
+         else:
+             self.logger.info("skipping loading, training from scratch")
+
+     def binarize(self, scope: str, ratio: float):
+         assert scope in ["local", "global"], "scope must be either local or global"
+         assert not self.binary, "binarization is not supported when using binary mask already"
+         if scope == "local":
+             # Local binarization
+             for i, lamb in enumerate(self.lambs):
+                 num_heads = lamb.size(0)
+                 num_activate_heads = int(num_heads * ratio)
+                 # Sort the lambda values with stable sorting to maintain order for equal values
+                 sorted_lamb, sorted_indices = torch.sort(lamb, descending=True, stable=True)
+                 # Find the threshold value
+                 threshold = sorted_lamb[num_activate_heads - 1]
+                 # Create a mask based on the sorted indices
+                 mask = torch.zeros_like(lamb)
+                 mask[sorted_indices[:num_activate_heads]] = 1.0
+                 # Binarize the lambda based on the threshold and the mask
+                 self.lambs[i] = torch.where(lamb > threshold, torch.ones_like(lamb), mask)
+         else:
+             # Global binarization
+             all_lambs = torch.cat([lamb.flatten() for lamb in self.lambs])
+             num_total = all_lambs.numel()
+             num_activate = int(num_total * ratio)
+             # Sort all lambda values globally with stable sorting
+             sorted_lambs, sorted_indices = torch.sort(all_lambs, descending=True, stable=True)
+             # Find the global threshold value
+             threshold = sorted_lambs[num_activate - 1]
+             # Create a global mask based on the sorted indices
+             global_mask = torch.zeros_like(all_lambs)
+             global_mask[sorted_indices[:num_activate]] = 1.0
+             # Binarize all lambdas based on the global threshold and mask
+             start_idx = 0
+             for i in range(len(self.lambs)):
+                 end_idx = start_idx + self.lambs[i].numel()
+                 lamb_mask = global_mask[start_idx:end_idx].reshape(self.lambs[i].shape)
+                 self.lambs[i] = torch.where(self.lambs[i] > threshold, torch.ones_like(self.lambs[i]), lamb_mask)
+                 start_idx = end_idx
+         self.binary = True
+
+     @staticmethod
+     def masking_fn(hidden_states, **kwargs):
+         hidden_states_dtype = hidden_states.dtype
+         lamb = kwargs["lamb"].view(1, 1, kwargs["lamb"].shape[0])
+         if kwargs.get("masking", None) == "sigmoid":
+             mask = torch.sigmoid(lamb)
+         elif kwargs.get("masking", None) == "binary":
+             mask = lamb
+         elif kwargs.get("masking", None) == "continues2binary":
+             # TODO: this might cause potential issue as it hard threshold at 0
+             mask = (lamb > 0).float()
+         elif kwargs.get("masking", None) == "no_masking":
+             mask = torch.ones_like(lamb)
+         else:
+             raise NotImplementedError
+         epsilon = kwargs.get("epsilon", 0.0)
+         hidden_states = hidden_states * mask + torch.randn_like(hidden_states) * epsilon * (1 - mask)
+         return hidden_states.to(hidden_states_dtype)
+
+     def get_ff_masking_hook(self, init_value=1.0):
+         """
+         Get a hook function to mask feed forward layer
+         """
+
+         def hook_fn(module, args, kwargs, output, name):
+             # initialize lambda with actual neuron count in the first run
+             if self.lambs[self.hook_counter] is None:
+                 self.lambs[self.hook_counter] = (
+                     torch.ones(self.module_neurons[name], device=self.pipeline.device, dtype=self.dtype) * init_value
+                 )
+                 self.lambs[self.hook_counter].requires_grad = True
+             # load ff lambda module name for logging
+             self.lambs_module_names[self.hook_counter] = name
+
+             # perform masking
+             output = self.masking_fn(
+                 output,
+                 masking=self.masking,
+                 lamb=self.lambs[self.hook_counter],
+                 epsilon=self.epsilon,
+                 eps=self.eps,
+                 use_log=self.use_log,
+             )
+             self.hook_counter += 1
+             self.hook_counter %= len(self.lambs)
+             return output
+
+         return hook_fn
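
FeedForwardHooker.binarize with scope="global" keeps the largest ratio fraction of neuron lambdas across all hooked FFN layers and zeroes the rest, while scope="local" applies the same rule per layer. A simplified standalone sketch of that thresholding rule (hypothetical, not part of the commit; the real method additionally resolves ties via a stable sort and an explicit index mask):

# Simplified global top-k thresholding over per-layer lambdas.
import torch

lambs = [torch.tensor([0.9, 0.1, 0.5]), torch.tensor([0.3, 0.8])]  # two toy layers
ratio = 0.6                                                         # keep 3 of the 5 neurons
all_lambs = torch.cat([lamb.flatten() for lamb in lambs])
num_keep = int(all_lambs.numel() * ratio)
threshold = torch.sort(all_lambs, descending=True, stable=True).values[num_keep - 1]
binary = [(lamb >= threshold).float() for lamb in lambs]
print(binary)  # [tensor([1., 0., 1.]), tensor([0., 1.])]
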
src/utils.py ADDED
@@ -0,0 +1,245 @@
+ import os
+ from copy import deepcopy
+ from typing import Optional
+
+ import torch
+ from diffusers.models.activations import GEGLU, GELU
+ from src.cross_attn_hook import CrossAttentionExtractionHook
+ from src.ffn_hooker import FeedForwardHooker
+
+
+ # create dummy module for skip connection
+ class SkipConnection(torch.nn.Module):
+     def __init__(self):
+         super(SkipConnection, self).__init__()
+
+     def forward(*args, **kwargs):
+         return args[1]
+
+
+ def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
+     total_num_lambs = 0
+     num_activate_lambs = 0
+     binary = getattr(hooker, "binary", None)  # if binary is not present, it will return None for ff_hooks
+     for lamb in hooker.lambs:
+         total_num_lambs += lamb.size(0)
+         if binary:
+             assert threshold is None, "threshold should be None for binary mask"
+             num_activate_lambs += lamb.sum().item()
+         else:
+             assert threshold is not None, "threshold must be provided for non-binary mask"
+             num_activate_lambs += (lamb >= threshold).sum().item()
+     return total_num_lambs, num_activate_lambs, num_activate_lambs / total_num_lambs
+
+
+ def create_pipeline(
+     pipe,
+     model_id,
+     device,
+     torch_dtype,
+     save_pt=None,
+     lambda_threshold: float = 1,
+     binary=True,
+     epsilon=0.0,
+     masking="binary",
+     attn_name="attn",
+     return_hooker=False,
+     scope=None,
+     ratio=None,
+ ):
+     """
+     create the pipeline and optionally load the saved mask
+     """
+     pipe.to(device)
+     pipe.vae.requires_grad_(False)
+     if hasattr(pipe, "unet"):
+         pipe.unet.requires_grad_(False)
+     else:
+         pipe.transformer.requires_grad_(False)
+     if save_pt:
+         # TODO should merge all the hooks checkpoint into one
+         if "ff.pt" in save_pt or "attn.pt" in save_pt:
+             save_pts = get_save_pts(save_pt)
+
+     cross_attn_hooker = CrossAttentionExtractionHook(
+         pipe,
+         model_name=model_id,
+         regex=".*",
+         dtype=torch_dtype,
+         head_num_filter=1,
+         masking=masking,  # need to change to binary during inference
+         dst=save_pts["attn"],
+         epsilon=epsilon,
+         attn_name=attn_name,
+         binary=binary,
+     )
+     cross_attn_hooker.add_hooks(init_value=1)
+     ff_hooker = FeedForwardHooker(
+         pipe,
+         regex=".*",
+         dtype=torch_dtype,
+         masking=masking,
+         dst=save_pts["ff"],
+         epsilon=epsilon,
+         binary=binary,
+     )
+     ff_hooker.add_hooks(init_value=1)
+     norm_hooker = None
+
+     g_cpu = torch.Generator(torch.device(device)).manual_seed(1)
+     _ = pipe("abc", generator=g_cpu, num_inference_steps=1)
+     cross_attn_hooker.load(device=device, threshold=lambda_threshold)
+     ff_hooker.load(device=device, threshold=lambda_threshold)
+     if norm_hooker:
+         norm_hooker.load(device=device, threshold=lambda_threshold)
+     if scope == "local" or scope == "global":
+         if isinstance(ratio, float):
+             attn_hooker_ratio = ratio
+             ff_hooker_ratio = ratio
+         else:
+             attn_hooker_ratio, ff_hooker_ratio = ratio[0], ratio[1]
+
+         if norm_hooker:
+             if len(ratio) < 3:
+                 raise ValueError("Need to provide ratio for norm layer")
+             norm_hooker_ratio = ratio[2]
+
+         cross_attn_hooker.binarize(scope, attn_hooker_ratio)
+         ff_hooker.binarize(scope, ff_hooker_ratio)
+         if norm_hooker:
+             norm_hooker.binarize(scope, norm_hooker_ratio)
+     hookers = [cross_attn_hooker, ff_hooker]
+     if norm_hooker:
+         hookers.append(norm_hooker)
+
+     if return_hooker:
+         return pipe, hookers
+     else:
+         return pipe
+
+
+ def linear_layer_pruning(module, lamb):
+     heads_to_keep = torch.nonzero(lamb).squeeze()
+     if len(heads_to_keep.shape) == 0:
+         # if only one head is kept, or none
+         heads_to_keep = heads_to_keep.unsqueeze(0)
+
+     modules_to_remove = [module.to_k, module.to_q, module.to_v]
+     new_heads = int(lamb.sum().item())
+
+     if new_heads == 0:
+         return SkipConnection()
+
+     for module_to_remove in modules_to_remove:
+         # get head dimension
+         inner_dim = module_to_remove.out_features // module.heads
+         # place holder for the rows to keep
+         rows_to_keep = torch.zeros(
+             module_to_remove.out_features, dtype=torch.bool, device=module_to_remove.weight.device
+         )
+
+         for idx in heads_to_keep:
+             rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True
+
+         # overwrite the inner projection with masked projection
+         module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
+         if module_to_remove.bias is not None:
+             module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
+         module_to_remove.out_features = int(sum(rows_to_keep).item())
+
+     # Also update the output projection layer if available, (for FLUXSingleAttnProcessor2_0)
+     # with column masking, dim 1
+     if getattr(module, "to_out", None) is not None:
+         module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
+         module.to_out[0].in_features = int(sum(rows_to_keep).item())
+
+     # update parameters in the attention module
+     module.inner_dim = module.inner_dim // module.heads * new_heads
+     try:
+         module.query_dim = module.query_dim // module.heads * new_heads
+         module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
+     except:
+         pass
+     module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
+     module.heads = new_heads
+     return module
+
+
+ def ffn_linear_layer_pruning(module, lamb):
+     lambda_to_keep = torch.nonzero(lamb).squeeze()
+     if len(lambda_to_keep) == 0:
+         return SkipConnection()
+
+     num_lambda = len(lambda_to_keep)
+
+     if isinstance(module.net[0], GELU):
+         # linear layer weight remove before activation
+         module.net[0].proj.weight.data = module.net[0].proj.weight.data[lambda_to_keep, :]
+         module.net[0].proj.out_features = num_lambda
+         if module.net[0].proj.bias is not None:
+             module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]
+
+         update_act = GELU(module.net[0].proj.in_features, num_lambda)
+         update_act.proj = module.net[0].proj
+         module.net[0] = update_act
+     elif isinstance(module.net[0], GEGLU):
+         output_feature = module.net[0].proj.out_features
+         module.net[0].proj.weight.data = torch.cat(
+             [
+                 module.net[0].proj.weight.data[: output_feature // 2, :][lambda_to_keep, :],
+                 module.net[0].proj.weight.data[output_feature // 2 :][lambda_to_keep, :],
+             ],
+             dim=0,
+         )
+         module.net[0].proj.out_features = num_lambda * 2
+         if module.net[0].proj.bias is not None:
+             module.net[0].proj.bias.data = torch.cat(
+                 [
+                     module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
+                     module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
+                 ]
+             )
+
+         update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
+         update_act.proj = module.net[0].proj
+         module.net[0] = update_act
+
+     # proj weight after activation
+     module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
+     module.net[2].in_features = num_lambda
+
+     return module
+
+
+ def get_save_pts(save_pt):
+     if "ff.pt" in save_pt:
+         ff_save_pt = deepcopy(save_pt)  # avoid in-place operation
+         attn_save_pt = save_pt.split(os.sep)
+         attn_save_pt[-1] = attn_save_pt[-1].replace("ff", "attn")
+         attn_save_pt_output = os.sep.join(attn_save_pt)
+         attn_save_pt[-1] = attn_save_pt[-1].replace("attn", "norm")
+         norm_save_pt = os.sep.join(attn_save_pt)
+
+         return {
+             "ff": ff_save_pt,
+             "attn": attn_save_pt_output,
+             "norm": norm_save_pt,
+         }
+     else:
+         attn_save_pt = deepcopy(save_pt)
+         ff_save_pt = save_pt.split(os.sep)
+         ff_save_pt[-1] = ff_save_pt[-1].replace("attn", "ff")
+         ff_save_pt_output = os.sep.join(ff_save_pt)
+         ff_save_pt[-1] = ff_save_pt[-1].replace("ff", "norm")
+         norm_save_pt = os.sep.join(ff_save_pt)
+
+         return {
+             "ff": ff_save_pt_output,
+             "attn": attn_save_pt,
+             "norm": norm_save_pt,
+         }
+
+
+ def save_img(pipe, g_cpu, steps, prompt, save_path):
+     image = pipe(prompt, generator=g_cpu, num_inference_steps=steps)
+     image["images"][0].save(save_path)
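
calculate_mask_sparsity only reads a hooker's lambs list and its binary flag, so any object exposing those attributes can be passed in. A quick hypothetical sketch (not part of the commit) with already-binarized toy masks:

# Count active entries in binary masks: 4 of 6 lambdas are kept here.
import torch
from types import SimpleNamespace
from src.utils import calculate_mask_sparsity

fake_hooker = SimpleNamespace(
    binary=True,
    lambs=[torch.tensor([1.0, 0.0, 1.0, 1.0]), torch.tensor([0.0, 1.0])],
)
total, active, mask_sparsity = calculate_mask_sparsity(fake_hooker)
print(total, active, mask_sparsity)  # 6 4.0 0.666...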