Commit 82d824b by zhangyang-0123
Parent(s): 3964559
add ecodiff demo

Files changed:
- app.py (+211, -141)
- mask/attn.pt (+0, -0)
- mask/ff.pt (+0, -0)
- requirements.txt (+4, -6)
- src/__init__.py (+0, -0)
- src/__pycache__/__init__.cpython-310.pyc (+0, -0)
- src/__pycache__/cross_attn_hook.cpython-310.pyc (+0, -0)
- src/__pycache__/ffn_hooker.cpython-310.pyc (+0, -0)
- src/__pycache__/utils.cpython-310.pyc (+0, -0)
- src/cross_attn_hook.py (+425, -0)
- src/ffn_hooker.py (+193, -0)
- src/utils.py (+245, -0)
app.py
CHANGED
@@ -1,154 +1,224 @@
 import gradio as gr
-import random
-import spaces  # [uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
 import torch
-
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-
-@spaces.GPU  # [uncomment to use ZeroGPU]
-def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
-):
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    return image, seed
-
-
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(" # Text-to-Image Gradio Template")
-
-        prompt = gr.Text(
-            label="Prompt",
-            show_label=False,
-            max_lines=1,
-            placeholder="Enter your prompt",
-            container=False,
-        )
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,  # Replace with defaults that work for your model
-                )
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=2,  # Replace with defaults that work for your model
-                )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-    gr.on(
-        triggers=[run_button.click, prompt.submit],
-        fn=infer,
-        inputs=[
-            prompt,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            num_inference_steps,
-        ],
-        outputs=[result, seed],
-    )
-
 if __name__ == "__main__":
-    demo.launch()

 import gradio as gr
+from dataclasses import dataclass

 import torch
+from tqdm import tqdm
+
+from src.utils import (
+    create_pipeline,
+    calculate_mask_sparsity,
+    ffn_linear_layer_pruning,
+    linear_layer_pruning,
+)
+from diffusers import StableDiffusionXLPipeline
+
+
+def get_model_param_summary(model, verbose=False):
+    params_dict = dict()
+    overall_params = 0
+    for name, params in model.named_parameters():
+        num_params = params.numel()
+        overall_params += num_params
+        if verbose:
+            print(f"GPU Memory Requirement for {name}: {params} MiB")
+        params_dict.update({name: num_params})
+    params_dict.update({"overall": overall_params})
+    return params_dict
+
+
+@dataclass
+class GradioArgs:
+    ckpt: str = "./mask/ff.pt"
+    device: str = "cuda:0"
+    seed: list = None
+    prompt: str = None
+    mix_precision: str = "bf16"
+    num_intervention_steps: int = 50
+    model: str = "sdxl"
+    binary: bool = False
+    masking: str = "binary"
+    scope: str = "global"
+    ratio: list = None
+    width: int = None
+    height: int = None
+    epsilon: float = 0.0
+    lambda_threshold: float = 0.001
+
+    def __post_init__(self):
+        if self.seed is None:
+            self.seed = [44]
+        if self.ratio is None:
+            self.ratio = [0.68, 0.88]
+
+
+def prune_model(pipe, hookers):
+    # remove parameters in attention blocks
+    cross_attn_hooker = hookers[0]
+    for name in tqdm(cross_attn_hooker.hook_dict.keys(), desc="Pruning attention layers"):
+        if getattr(pipe, "unet", None):
+            module = pipe.unet.get_submodule(name)
+        else:
+            module = pipe.transformer.get_submodule(name)
+        lamb = cross_attn_hooker.lambs[cross_attn_hooker.lambs_module_names.index(name)]
+        assert module.heads == lamb.shape[0]
+        module = linear_layer_pruning(module, lamb)
+
+        parent_module_name, child_name = name.rsplit(".", 1)
+        if getattr(pipe, "unet", None):
+            parent_module = pipe.unet.get_submodule(parent_module_name)
+        else:
+            parent_module = pipe.transformer.get_submodule(parent_module_name)
+        setattr(parent_module, child_name, module)
+
+    # remove parameters in FFN blocks
+    ffn_hook = hookers[1]
+    for name in tqdm(ffn_hook.hook_dict.keys(), desc="Pruning FFN linear layers"):
+        if getattr(pipe, "unet", None):
+            module = pipe.unet.get_submodule(name)
+        else:
+            module = pipe.transformer.get_submodule(name)
+        lamb = ffn_hook.lambs[ffn_hook.lambs_module_names.index(name)]
+        module = ffn_linear_layer_pruning(module, lamb)
+
+        parent_module_name, child_name = name.rsplit(".", 1)
+        if getattr(pipe, "unet", None):
+            parent_module = pipe.unet.get_submodule(parent_module_name)
+        else:
+            parent_module = pipe.transformer.get_submodule(parent_module_name)
+        setattr(parent_module, child_name, module)
+
+    cross_attn_hooker.clear_hooks()
+    ffn_hook.clear_hooks()
+    return pipe
+
+
+def binary_mask_eval(args):
+    # load the SDXL model
+    pipe = StableDiffusionXLPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+    ).to(args.device)
+
+    device = args.device
+    torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
+    mask_pipe, hookers = create_pipeline(
+        pipe,
+        args.model,
+        device,
+        torch_dtype,
+        args.ckpt,
+        binary=args.binary,
+        lambda_threshold=args.lambda_threshold,
+        epsilon=args.epsilon,
+        masking=args.masking,
+        return_hooker=True,
+        scope=args.scope,
+        ratio=args.ratio,
+    )

+    # Print mask sparsity info
+    threshold = None if args.binary else args.lambda_threshold
+    threshold = None if args.scope is not None else threshold
+    name = ["ff", "attn"]
+    for n, hooker in zip(name, hookers):
+        total_num_heads, num_activate_heads, mask_sparsity = calculate_mask_sparsity(hooker, threshold)
+        print(f"model: {args.model}, {n} masking: {args.masking}")
+        print(
+            f"total num heads: {total_num_heads}, "
+            + f"num activate heads: {num_activate_heads}, mask sparsity: {mask_sparsity}"
+        )

+    # Prune the model
+    pruned_pipe = prune_model(mask_pipe, hookers)

+    # reload the original model
+    pipe = StableDiffusionXLPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+    ).to(args.device)

+    # get model param summary
+    print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
+    print(f"pruned model param: {get_model_param_summary(pruned_pipe.unet)['overall']}")
+    print("prune complete")
+    return pipe, pruned_pipe
+
+
+def generate_images(prompt, seed, steps, pipe, pruned_pipe):
+    # Run both models and return the images directly
+    g_cpu = torch.Generator("cuda:0").manual_seed(seed)
+    original_image = pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
+    g_cpu = torch.Generator("cuda:0").manual_seed(seed)
+    ecodiff_image = pruned_pipe(prompt=prompt, generator=g_cpu, num_inference_steps=steps).images[0]
+    return original_image, ecodiff_image
+
+
+def on_prune_click(prompt, seed, steps):
+    args = GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
+    pipe, pruned_pipe = binary_mask_eval(args)
+    return pipe, pruned_pipe, [("Model Initialized", "green")]

+def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
+    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
+    return original_image, ecodiff_image
+
+
+def create_demo():
+    with gr.Blocks() as demo:
+        gr.Markdown("# Text-to-Image Generation with EcoDiff Pruned Model")
+        with gr.Row():
+            gr.Markdown(
+                """
+                # 🚧 Under Construction 🚧
+                This demo is currently being developed and may not be fully functional. More models and pruning ratios will be supported soon.
+                The current pruned model checkpoint is not optimal and does not provide the best performance.
+
+                **Note: Please first initialize the model before generating images.**
+                """
+            )
+        with gr.Row():
+            model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
+            pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
+            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
+            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
+        with gr.Row():
+            prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
+            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
+            steps = gr.Slider(label="Number of Steps", minimum=1, maximum=100, value=50, step=1, scale=1)
+        generate_btn = gr.Button("Generate Images")
+        gr.Examples(
+            examples=[
+                "A clock tower floating in a sea of clouds",
+                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+                "An astronaut riding a green horse",
+                "A delicious ceviche cheesecake slice",
+            ],
+            inputs=[prompt],
+        )
+        with gr.Row():
+            original_output = gr.Image(label="Original Output")
+            ecodiff_output = gr.Image(label="EcoDiff Output")
+
+        pipe_state = gr.State(None)
+        pruned_pipe_state = gr.State(None)
+        prompt.submit(
+            fn=on_generate_click,
+            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
+            outputs=[original_output, ecodiff_output],
+        )
+        prune_btn.click(
+            fn=on_prune_click,
+            inputs=[prompt, seed, steps],
+            outputs=[pipe_state, pruned_pipe_state, status_label],
+        )
+        generate_btn.click(
+            fn=on_generate_click,
+            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
+            outputs=[original_output, ecodiff_output],
+        )
+
+    return demo

 if __name__ == "__main__":
+    demo = create_demo()
+    demo.launch(share=True)
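For reference, the new entry points in app.py (GradioArgs, binary_mask_eval, generate_images) can also be driven without the Gradio UI. Below is a minimal sketch, assuming the Space's files are checked out locally on a machine with a CUDA device and the SDXL weights available; the script name and output filenames are hypothetical, not part of this commit:

# headless_ecodiff.py -- illustrative sketch only, not part of this commit
from app import GradioArgs, binary_mask_eval, generate_images

# GradioArgs defaults to the ./mask/ff.pt checkpoint, global scope, and a [0.68, 0.88] keep ratio
args = GradioArgs(prompt="A clock tower floating in a sea of clouds", seed=[44])
pipe, pruned_pipe = binary_mask_eval(args)  # loads SDXL, applies the masks, prunes a copy
original, pruned = generate_images(args.prompt, args.seed[0], args.num_intervention_steps, pipe, pruned_pipe)
original.save("sdxl_original.png")  # both outputs are PIL images returned by the pipelines
pruned.save("sdxl_ecodiff_pruned.png")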
mask/attn.pt
ADDED
Binary file (40.5 kB).

mask/ff.pt
ADDED
Binary file (686 kB).
requirements.txt
CHANGED
@@ -1,6 +1,4 @@
-transformers
-xformers
+diffusers==0.31.0
+torch==2.4.1
+transformers==4.45.2
+accelerate==0.33.0
src/__init__.py
ADDED
File without changes

src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (198 Bytes).

src/__pycache__/cross_attn_hook.cpython-310.pyc
ADDED
Binary file (12.8 kB).

src/__pycache__/ffn_hooker.cpython-310.pyc
ADDED
Binary file (6.4 kB).

src/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (5.53 kB).
src/cross_attn_hook.py
ADDED
@@ -0,0 +1,425 @@
import logging
import os
from collections import OrderedDict
from functools import partial

import re

import math
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.utils import deprecate


def scaled_dot_product_attention_atten_weight_only(
    query, key, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight


def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def masking_fn(hidden_states, kwargs):
    lamb = kwargs["lamb"].view(1, kwargs["lamb"].shape[0], 1, 1)
    if kwargs.get("masking", None) == "sigmoid":
        mask = torch.sigmoid(lamb)
    elif kwargs.get("masking", None) == "binary":
        mask = lamb
    elif kwargs.get("masking", None) == "continues2binary":
        # TODO: this might cause a potential issue as it hard-thresholds at 0
        mask = (lamb > 0).float()
    elif kwargs.get("masking", None) == "no_masking":
        mask = torch.ones_like(lamb)
    else:
        raise NotImplementedError
    epsilon = kwargs.get("epsilon", 0.0)
    hidden_states = hidden_states * mask + torch.randn_like(hidden_states) * epsilon * (1 - mask)
    return hidden_states


class AttnProcessor2_0_Masking:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ):
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = (
                "The `scale` argument is deprecated and will be ignored. "
                "Please remove it, as passing it will raise an error "
                "in the future. `scale` should directly be passed while "
                "calling the underlying pipeline component i.e., via "
                "`cross_attention_kwargs`."
            )
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if getattr(attn, "norm_q", None) is not None:
            query = attn.norm_q(query)

        if getattr(attn, "norm_k", None) is not None:
            key = attn.norm_k(key)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        if kwargs.get("return_attention", True):
            # also compute the attention weights alongside the F.scaled_dot_product_attention output
            attn_weight = scaled_dot_product_attention_atten_weight_only(
                query, key, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
            hidden_states_aft_attention_ops = hidden_states.clone()
            attn_weight_old = attn_weight.to(hidden_states.device).clone()
        else:
            hidden_states_aft_attention_ops = None
            attn_weight_old = None

        # masking for the hidden_states after the attention ops
        if kwargs.get("lamb", None) is not None:
            hidden_states = masking_fn(hidden_states, kwargs)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states, hidden_states_aft_attention_ops, attn_weight_old


class BaseCrossAttentionHooker:
    def __init__(self, pipeline, regex, dtype, head_num_filter, masking, model_name, attn_name, use_log, eps):
        self.pipeline = pipeline
        # unet for SD2/SDXL, transformer for SD3 and FLUX DiT
        self.net = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
        self.model_name = model_name
        self.module_heads = OrderedDict()
        self.masking = masking
        self.hook_dict = {}
        self.regex = regex
        self.dtype = dtype
        self.head_num_filter = head_num_filter
        self.attn_name = attn_name
        self.logger = logging.getLogger(__name__)
        self.use_log = use_log  # use log parameter to control hard_discrete
        self.eps = eps

    def add_hooks_to_cross_attention(self, hook_fn: callable):
        """
        Add forward hooks to every cross-attention module.
        :param hook_fn: a callable to be added to a torch nn module as a hook
        :return:
        """
        total_hooks = 0
        for name, module in self.net.named_modules():
            name_last_word = name.split(".")[-1]
            if self.attn_name in name_last_word:
                if re.match(self.regex, name):
                    hook_fn = partial(hook_fn, name=name)
                    hook = module.register_forward_hook(hook_fn, with_kwargs=True)
                    self.hook_dict[name] = hook
                    self.module_heads[name] = module.heads
                    self.logger.info(f"Adding hook to {name}, module.heads: {module.heads}")
                    total_hooks += 1
        self.logger.info(f"Total hooks added: {total_hooks}")

    def clear_hooks(self):
        """clear all hooks"""
        for hook in self.hook_dict.values():
            hook.remove()
        self.hook_dict.clear()


class CrossAttentionExtractionHook(BaseCrossAttentionHooker):
    def __init__(
        self,
        pipeline,
        dtype,
        head_num_filter,
        masking,
        dst,
        regex=None,
        epsilon=0.0,
        binary=False,
        return_attention=False,
        model_name="sdxl",
        attn_name="attn",
        use_log=False,
        eps=1e-6,
    ):
        super().__init__(
            pipeline,
            regex,
            dtype,
            head_num_filter,
            masking=masking,
            model_name=model_name,
            attn_name=attn_name,
            use_log=use_log,
            eps=eps,
        )
        self.attention_processor = AttnProcessor2_0_Masking()
        self.lambs = []
        self.lambs_module_names = []
        self.cross_attn = []
        self.hook_counter = 0
        self.device = self.pipeline.unet.device if hasattr(self.pipeline, "unet") else self.pipeline.transformer.device
        self.dst = dst
        self.epsilon = epsilon
        self.binary = binary
        self.return_attention = return_attention
        self.model_name = model_name

    def clean_cross_attn(self):
        self.cross_attn = []

    def validate_dst(self):
        if os.path.exists(self.dst):
            raise ValueError(f"Destination {self.dst} already exists")

    def save(self, name: str = None):
        if name is not None:
            dst = os.path.join(os.path.dirname(self.dst), name)
        else:
            dst = self.dst
        dst_dir = os.path.dirname(dst)
        if not os.path.exists(dst_dir):
            self.logger.info(f"Creating directory {dst_dir}")
            os.makedirs(dst_dir)
        torch.save(self.lambs, dst)

    @property
    def get_lambda_block_names(self):
        return self.lambs_module_names

    def load(self, device, threshold=2.5):
        if os.path.exists(self.dst):
            self.logger.info(f"loading lambda from {self.dst}")
            self.lambs = torch.load(self.dst, weights_only=True, map_location=device)
            if self.binary:
                # set binary masking for each lambda by thresholding
                self.lambs = [(torch.relu(lamb - threshold) > 0).float() for lamb in self.lambs]
        else:
            self.logger.info("skipping loading, training from scratch")

    def binarize(self, scope: str, ratio: float):
        assert scope in ["local", "global"], "scope must be either local or global"
        assert not self.binary, "binarization is not supported when a binary mask is already in use"
        if scope == "local":
            # Local binarization
            for i, lamb in enumerate(self.lambs):
                num_heads = lamb.size(0)
                num_activate_heads = int(num_heads * ratio)
                # Sort the lambda values with stable sorting to maintain order for equal values
                sorted_lamb, sorted_indices = torch.sort(lamb, descending=True, stable=True)
                # Find the threshold value
                threshold = sorted_lamb[num_activate_heads - 1]
                # Create a mask based on the sorted indices
                mask = torch.zeros_like(lamb)
                mask[sorted_indices[:num_activate_heads]] = 1.0
                # Binarize the lambda based on the threshold and the mask
                self.lambs[i] = torch.where(lamb > threshold, torch.ones_like(lamb), mask)
        else:
            # Global binarization
            all_lambs = torch.cat([lamb.flatten() for lamb in self.lambs])
            num_total = all_lambs.numel()
            num_activate = int(num_total * ratio)
            # Sort all lambda values globally with stable sorting
            sorted_lambs, sorted_indices = torch.sort(all_lambs, descending=True, stable=True)
            # Find the global threshold value
            threshold = sorted_lambs[num_activate - 1]
            # Create a global mask based on the sorted indices
            global_mask = torch.zeros_like(all_lambs)
            global_mask[sorted_indices[:num_activate]] = 1.0
            # Binarize all lambdas based on the global threshold and mask
            start_idx = 0
            for i in range(len(self.lambs)):
                end_idx = start_idx + self.lambs[i].numel()
                lamb_mask = global_mask[start_idx:end_idx].reshape(self.lambs[i].shape)
                self.lambs[i] = torch.where(self.lambs[i] > threshold, torch.ones_like(self.lambs[i]), lamb_mask)
                start_idx = end_idx
        self.binary = True

    def bizarize_threshold(self, threshold: float):
        """
        Binarize lambda values based on a predefined threshold.
        :param threshold: The threshold value for binarization
        """
        assert not self.binary, "Binarization is not supported when a binary mask is already in use"

        for i in range(len(self.lambs)):
            self.lambs[i] = (self.lambs[i] >= threshold).float()

        self.binary = True

    def get_cross_attn_extraction_hook(self, init_value=1.0):
        """get a hook function to extract cross attention"""

        # a closure is used so the hook can store the extracted cross attention on this instance
        def hook_fn(module, args, kwargs, output, name):
            # initialize lambda with the actual head dim on the first run
            if self.lambs[self.hook_counter] is None:
                self.lambs[self.hook_counter] = (
                    torch.ones(module.heads, device=self.pipeline.device, dtype=self.dtype) * init_value
                )
                # Only set requires_grad to True when the head number is larger than the filter
                if self.head_num_filter <= module.heads:
                    self.lambs[self.hook_counter].requires_grad = True

            # record the attention lambda module name for logging
            self.lambs_module_names[self.hook_counter] = name

            hidden_states, _, attention_output = self.attention_processor(
                module,
                args[0],
                encoder_hidden_states=kwargs["encoder_hidden_states"],
                attention_mask=kwargs["attention_mask"],
                lamb=self.lambs[self.hook_counter],
                masking=self.masking,
                epsilon=self.epsilon,
                return_attention=self.return_attention,
                use_log=self.use_log,
                eps=self.eps,
            )
            if attention_output is not None:
                self.cross_attn.append(attention_output)
            self.hook_counter += 1
            self.hook_counter %= len(self.lambs)
            return hidden_states

        return hook_fn

    def add_hooks(self, init_value=1.0):
        hook_fn = self.get_cross_attn_extraction_hook(init_value)
        self.add_hooks_to_cross_attention(hook_fn)
        # initialize the lambdas
        self.lambs = [None] * len(self.module_heads)
        # initialize the lambda module names
        self.lambs_module_names = [None] * len(self.module_heads)

    def get_process_cross_attn_result(self, text_seq_length, timestep: int = -1):
        if isinstance(timestep, str):
            timestep = int(timestep)
        # num_lambda_block is the number of blocks that carry a lambda (head masking)
        num_lambda_block = len(self.lambs)

        # get the start and end position of the timestep
        start_pos = timestep * num_lambda_block
        end_pos = (timestep + 1) * num_lambda_block
        if end_pos > len(self.cross_attn):
            raise ValueError(f"timestep {timestep} is out of range")

        # list[cross_attn_map]: num_layer x [batch, num_heads, seq_vis_tokens, seq_text_tokens]
        attn_maps = self.cross_attn[start_pos:end_pos]

        def heatmap(attn_list, attn_idx, head_idx, text_idx):
            # only select the second element in the tuple (the text-guided attention)
            # layer_idx, 1, head_idx, seq_vis_tokens, seq_text_tokens
            map = attn_list[attn_idx][1][head_idx][:][:, text_idx]
            # get the size of the heatmap
            size = int(map.shape[0] ** 0.5)
            map = map.view(size, size, 1)
            data = map.cpu().float().numpy()
            return data

        output_dict = {}
        for lambda_block_idx, lambda_block_name in zip(range(num_lambda_block), self.lambs_module_names):
            data_list = []
            for head_idx in range(len(self.lambs[lambda_block_idx])):
                for token_idx in range(text_seq_length):
                    # the number of heatmaps equals the number of text tokens times the number of heads
                    data_list.append(heatmap(attn_maps, lambda_block_idx, head_idx, token_idx))
            output_dict[lambda_block_name] = {"attn_map": data_list, "lambda": self.lambs[lambda_block_idx]}
        return output_dict
src/ffn_hooker.py
ADDED
@@ -0,0 +1,193 @@
import logging
import os
from collections import OrderedDict
from functools import partial

import diffusers
import torch
from torch import nn

import re


class FeedForwardHooker:
    def __init__(
        self,
        pipeline: nn.Module,
        regex: str,
        dtype: torch.dtype,
        masking: str,
        dst: str,
        epsilon: float = 0.0,
        eps: float = 1e-6,
        use_log: bool = False,
        binary: bool = False,
    ):
        self.pipeline = pipeline
        self.net = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
        self.logger = logging.getLogger(__name__)
        self.dtype = dtype
        self.regex = regex
        self.hook_dict = {}
        self.masking = masking
        self.dst = dst
        self.epsilon = epsilon
        self.eps = eps
        self.use_log = use_log
        self.lambs = []
        self.lambs_module_names = []  # store the module names for each lambda block
        self.hook_counter = 0
        self.module_neurons = OrderedDict()
        self.binary = binary  # default; need to discuss whether to keep this attribute or not

    def add_hooks_to_ff(self, hook_fn: callable):
        total_hooks = 0
        for name, module in self.net.named_modules():
            name_last_word = name.split(".")[-1]
            if "ff" in name_last_word:
                if re.match(self.regex, name):
                    hook_fn_with_name = partial(hook_fn, name=name)
                    actual_module = module.net[0]
                    hook = actual_module.register_forward_hook(hook_fn_with_name, with_kwargs=True)
                    self.hook_dict[name] = hook

                    if isinstance(actual_module, diffusers.models.activations.GEGLU):  # geglu
                        # due to the GEGLU chunking, we need to divide by 2
                        self.module_neurons[name] = actual_module.proj.out_features // 2
                    elif isinstance(actual_module, diffusers.models.activations.GELU):  # gelu
                        self.module_neurons[name] = actual_module.proj.out_features
                    else:
                        raise NotImplementedError(f"Module {name} is not implemented, please check")
                    self.logger.info(f"Adding hook to {name}, neurons: {self.module_neurons[name]}")
                    total_hooks += 1
        self.logger.info(f"Total hooks added: {total_hooks}")
        return self.hook_dict

    def add_hooks(self, init_value=1.0):
        hook_fn = self.get_ff_masking_hook(init_value)
        self.add_hooks_to_ff(hook_fn)
        # initialize the lambdas
        self.lambs = [None] * len(self.hook_dict)
        # initialize the lambda module names
        self.lambs_module_names = [None] * len(self.hook_dict)

    def clear_hooks(self):
        """clear all hooks"""
        for hook in self.hook_dict.values():
            hook.remove()
        self.hook_dict.clear()

    def save(self, name: str = None):
        if name is not None:
            dst = os.path.join(os.path.dirname(self.dst), name)
        else:
            dst = self.dst
        dst_dir = os.path.dirname(dst)
        if not os.path.exists(dst_dir):
            self.logger.info(f"Creating directory {dst_dir}")
            os.makedirs(dst_dir)
        torch.save(self.lambs, dst)

    @property
    def get_lambda_block_names(self):
        return self.lambs_module_names

    def load(self, device, threshold=2.5):
        if os.path.exists(self.dst):
            self.logger.info(f"loading lambda from {self.dst}")
            self.lambs = torch.load(self.dst, weights_only=True, map_location=device)
            if self.binary:
                # set binary masking for each lambda by thresholding
                self.lambs = [(torch.relu(lamb - threshold) > 0).float() for lamb in self.lambs]
            else:
                self.lambs = [torch.clamp(lamb, min=0.0) for lamb in self.lambs]
                # self.lambs_module_names = [None for _ in self.lambs]
        else:
            self.logger.info("skipping loading, training from scratch")

    def binarize(self, scope: str, ratio: float):
        assert scope in ["local", "global"], "scope must be either local or global"
        assert not self.binary, "binarization is not supported when a binary mask is already in use"
        if scope == "local":
            # Local binarization
            for i, lamb in enumerate(self.lambs):
                num_heads = lamb.size(0)
                num_activate_heads = int(num_heads * ratio)
                # Sort the lambda values with stable sorting to maintain order for equal values
                sorted_lamb, sorted_indices = torch.sort(lamb, descending=True, stable=True)
                # Find the threshold value
                threshold = sorted_lamb[num_activate_heads - 1]
                # Create a mask based on the sorted indices
                mask = torch.zeros_like(lamb)
                mask[sorted_indices[:num_activate_heads]] = 1.0
                # Binarize the lambda based on the threshold and the mask
                self.lambs[i] = torch.where(lamb > threshold, torch.ones_like(lamb), mask)
        else:
            # Global binarization
            all_lambs = torch.cat([lamb.flatten() for lamb in self.lambs])
            num_total = all_lambs.numel()
            num_activate = int(num_total * ratio)
            # Sort all lambda values globally with stable sorting
            sorted_lambs, sorted_indices = torch.sort(all_lambs, descending=True, stable=True)
            # Find the global threshold value
            threshold = sorted_lambs[num_activate - 1]
            # Create a global mask based on the sorted indices
            global_mask = torch.zeros_like(all_lambs)
            global_mask[sorted_indices[:num_activate]] = 1.0
            # Binarize all lambdas based on the global threshold and mask
            start_idx = 0
            for i in range(len(self.lambs)):
                end_idx = start_idx + self.lambs[i].numel()
                lamb_mask = global_mask[start_idx:end_idx].reshape(self.lambs[i].shape)
                self.lambs[i] = torch.where(self.lambs[i] > threshold, torch.ones_like(self.lambs[i]), lamb_mask)
                start_idx = end_idx
        self.binary = True

    @staticmethod
    def masking_fn(hidden_states, **kwargs):
        hidden_states_dtype = hidden_states.dtype
        lamb = kwargs["lamb"].view(1, 1, kwargs["lamb"].shape[0])
        if kwargs.get("masking", None) == "sigmoid":
            mask = torch.sigmoid(lamb)
        elif kwargs.get("masking", None) == "binary":
            mask = lamb
        elif kwargs.get("masking", None) == "continues2binary":
            # TODO: this might cause a potential issue as it hard-thresholds at 0
            mask = (lamb > 0).float()
        elif kwargs.get("masking", None) == "no_masking":
            mask = torch.ones_like(lamb)
        else:
            raise NotImplementedError
        epsilon = kwargs.get("epsilon", 0.0)
        hidden_states = hidden_states * mask + torch.randn_like(hidden_states) * epsilon * (1 - mask)
        return hidden_states.to(hidden_states_dtype)

    def get_ff_masking_hook(self, init_value=1.0):
        """
        Get a hook function that masks the feed-forward layer output.
        """

        def hook_fn(module, args, kwargs, output, name):
            # initialize lambda with the actual neuron dim on the first run
            if self.lambs[self.hook_counter] is None:
                self.lambs[self.hook_counter] = (
                    torch.ones(self.module_neurons[name], device=self.pipeline.device, dtype=self.dtype) * init_value
                )
                self.lambs[self.hook_counter].requires_grad = True
            # record the ff lambda module name for logging
            self.lambs_module_names[self.hook_counter] = name

            # perform masking
            output = self.masking_fn(
                output,
                masking=self.masking,
                lamb=self.lambs[self.hook_counter],
                epsilon=self.epsilon,
                eps=self.eps,
                use_log=self.use_log,
            )
            self.hook_counter += 1
            self.hook_counter %= len(self.lambs)
            return output

        return hook_fn
src/utils.py
ADDED
@@ -0,0 +1,245 @@
import os
from copy import deepcopy
from typing import Optional

import torch
from diffusers.models.activations import GEGLU, GELU
from src.cross_attn_hook import CrossAttentionExtractionHook
from src.ffn_hooker import FeedForwardHooker


# dummy module used to replace fully pruned blocks with a skip connection
class SkipConnection(torch.nn.Module):
    def __init__(self):
        super(SkipConnection, self).__init__()

    def forward(*args, **kwargs):
        return args[1]


def calculate_mask_sparsity(hooker, threshold: Optional[float] = None):
    total_num_lambs = 0
    num_activate_lambs = 0
    binary = getattr(hooker, "binary", None)  # if binary is not present, it returns None for ff hooks
    for lamb in hooker.lambs:
        total_num_lambs += lamb.size(0)
        if binary:
            assert threshold is None, "threshold should be None for binary mask"
            num_activate_lambs += lamb.sum().item()
        else:
            assert threshold is not None, "threshold must be provided for non-binary mask"
            num_activate_lambs += (lamb >= threshold).sum().item()
    return total_num_lambs, num_activate_lambs, num_activate_lambs / total_num_lambs


def create_pipeline(
    pipe,
    model_id,
    device,
    torch_dtype,
    save_pt=None,
    lambda_threshold: float = 1,
    binary=True,
    epsilon=0.0,
    masking="binary",
    attn_name="attn",
    return_hooker=False,
    scope=None,
    ratio=None,
):
    """
    Create the pipeline and optionally load the saved mask.
    """
    pipe.to(device)
    pipe.vae.requires_grad_(False)
    if hasattr(pipe, "unet"):
        pipe.unet.requires_grad_(False)
    else:
        pipe.transformer.requires_grad_(False)
    if save_pt:
        # TODO: should merge all the hook checkpoints into one
        if "ff.pt" in save_pt or "attn.pt" in save_pt:
            save_pts = get_save_pts(save_pt)

        cross_attn_hooker = CrossAttentionExtractionHook(
            pipe,
            model_name=model_id,
            regex=".*",
            dtype=torch_dtype,
            head_num_filter=1,
            masking=masking,  # needs to be changed to binary during inference
            dst=save_pts["attn"],
            epsilon=epsilon,
            attn_name=attn_name,
            binary=binary,
        )
        cross_attn_hooker.add_hooks(init_value=1)
        ff_hooker = FeedForwardHooker(
            pipe,
            regex=".*",
            dtype=torch_dtype,
            masking=masking,
            dst=save_pts["ff"],
            epsilon=epsilon,
            binary=binary,
        )
        ff_hooker.add_hooks(init_value=1)
        norm_hooker = None

        g_cpu = torch.Generator(torch.device(device)).manual_seed(1)
        _ = pipe("abc", generator=g_cpu, num_inference_steps=1)
        cross_attn_hooker.load(device=device, threshold=lambda_threshold)
        ff_hooker.load(device=device, threshold=lambda_threshold)
        if norm_hooker:
            norm_hooker.load(device=device, threshold=lambda_threshold)
        if scope == "local" or scope == "global":
            if isinstance(ratio, float):
                attn_hooker_ratio = ratio
                ff_hooker_ratio = ratio
            else:
                attn_hooker_ratio, ff_hooker_ratio = ratio[0], ratio[1]

            if norm_hooker:
                if len(ratio) < 3:
                    raise ValueError("Need to provide ratio for norm layer")
                norm_hooker_ratio = ratio[2]

            cross_attn_hooker.binarize(scope, attn_hooker_ratio)
            ff_hooker.binarize(scope, ff_hooker_ratio)
            if norm_hooker:
                norm_hooker.binarize(scope, norm_hooker_ratio)
        hookers = [cross_attn_hooker, ff_hooker]
        if norm_hooker:
            hookers.append(norm_hooker)

        if return_hooker:
            return pipe, hookers
        else:
            return pipe


def linear_layer_pruning(module, lamb):
    heads_to_keep = torch.nonzero(lamb).squeeze()
    if len(heads_to_keep.shape) == 0:
        # if only one head is kept, or none
        heads_to_keep = heads_to_keep.unsqueeze(0)

    modules_to_remove = [module.to_k, module.to_q, module.to_v]
    new_heads = int(lamb.sum().item())

    if new_heads == 0:
        return SkipConnection()

    for module_to_remove in modules_to_remove:
        # get head dimension
        inner_dim = module_to_remove.out_features // module.heads
        # placeholder for the rows to keep
        rows_to_keep = torch.zeros(
            module_to_remove.out_features, dtype=torch.bool, device=module_to_remove.weight.device
        )

        for idx in heads_to_keep:
            rows_to_keep[idx * inner_dim : (idx + 1) * inner_dim] = True

        # overwrite the inner projection with the masked projection
        module_to_remove.weight.data = module_to_remove.weight.data[rows_to_keep, :]
        if module_to_remove.bias is not None:
            module_to_remove.bias.data = module_to_remove.bias.data[rows_to_keep]
        module_to_remove.out_features = int(sum(rows_to_keep).item())

    # Also update the output projection layer if available (for FLUXSingleAttnProcessor2_0),
    # with column masking along dim 1
    if getattr(module, "to_out", None) is not None:
        module.to_out[0].weight.data = module.to_out[0].weight.data[:, rows_to_keep]
        module.to_out[0].in_features = int(sum(rows_to_keep).item())

    # update parameters in the attention module
    module.inner_dim = module.inner_dim // module.heads * new_heads
    try:
        module.query_dim = module.query_dim // module.heads * new_heads
        module.inner_kv_dim = module.inner_kv_dim // module.heads * new_heads
    except AttributeError:
        # some attention variants do not expose these attributes
        pass
    module.cross_attention_dim = module.cross_attention_dim // module.heads * new_heads
    module.heads = new_heads
    return module


def ffn_linear_layer_pruning(module, lamb):
    lambda_to_keep = torch.nonzero(lamb).squeeze()
    if len(lambda_to_keep) == 0:
        return SkipConnection()

    num_lambda = len(lambda_to_keep)

    if isinstance(module.net[0], GELU):
        # remove rows of the linear layer weight before the activation
        module.net[0].proj.weight.data = module.net[0].proj.weight.data[lambda_to_keep, :]
        module.net[0].proj.out_features = num_lambda
        if module.net[0].proj.bias is not None:
            module.net[0].proj.bias.data = module.net[0].proj.bias.data[lambda_to_keep]

        update_act = GELU(module.net[0].proj.in_features, num_lambda)
        update_act.proj = module.net[0].proj
        module.net[0] = update_act
    elif isinstance(module.net[0], GEGLU):
        output_feature = module.net[0].proj.out_features
        module.net[0].proj.weight.data = torch.cat(
            [
                module.net[0].proj.weight.data[: output_feature // 2, :][lambda_to_keep, :],
                module.net[0].proj.weight.data[output_feature // 2 :][lambda_to_keep, :],
            ],
            dim=0,
        )
        module.net[0].proj.out_features = num_lambda * 2
        if module.net[0].proj.bias is not None:
            module.net[0].proj.bias.data = torch.cat(
                [
                    module.net[0].proj.bias.data[: output_feature // 2][lambda_to_keep],
                    module.net[0].proj.bias.data[output_feature // 2 :][lambda_to_keep],
                ]
            )

        update_act = GEGLU(module.net[0].proj.in_features, num_lambda * 2)
        update_act.proj = module.net[0].proj
        module.net[0] = update_act

    # prune the columns of the projection weight after the activation
    module.net[2].weight.data = module.net[2].weight.data[:, lambda_to_keep]
    module.net[2].in_features = num_lambda

    return module


def get_save_pts(save_pt):
    if "ff.pt" in save_pt:
        ff_save_pt = deepcopy(save_pt)  # avoid in-place operation
        attn_save_pt = save_pt.split(os.sep)
        attn_save_pt[-1] = attn_save_pt[-1].replace("ff", "attn")
        attn_save_pt_output = os.sep.join(attn_save_pt)
        attn_save_pt[-1] = attn_save_pt[-1].replace("attn", "norm")
        norm_save_pt = os.sep.join(attn_save_pt)

        return {
            "ff": ff_save_pt,
            "attn": attn_save_pt_output,
            "norm": norm_save_pt,
        }
    else:
        attn_save_pt = deepcopy(save_pt)
        ff_save_pt = save_pt.split(os.sep)
        ff_save_pt[-1] = ff_save_pt[-1].replace("attn", "ff")
        ff_save_pt_output = os.sep.join(ff_save_pt)
        ff_save_pt[-1] = ff_save_pt[-1].replace("ff", "norm")
        norm_save_pt = os.sep.join(attn_save_pt)

        return {
            "ff": ff_save_pt_output,
            "attn": attn_save_pt,
            "norm": norm_save_pt,
        }


def save_img(pipe, g_cpu, steps, prompt, save_path):
    image = pipe(prompt, generator=g_cpu, num_inference_steps=steps)
    image["images"][0].save(save_path)