import gradio as gr
from dataclasses import dataclass
from typing import Optional

import torch
from tqdm import tqdm

from src.utils import (
    create_pipeline,
    calculate_mask_sparsity,
    ffn_linear_layer_pruning,
    linear_layer_pruning,
)
from diffusers import StableDiffusionXLPipeline


def get_model_param_summary(model, verbose=False):
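    """Return a dict mapping each parameter name to its element count, plus an 'overall' total."""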
    params_dict = dict()
    overall_params = 0
    for name, params in model.named_parameters():
        num_params = params.numel()
        overall_params += num_params
        if verbose:
            print(f"GPU Memory Requirement for {name}: {params} MiB")
        params_dict.update({name: num_params})
    params_dict.update({"overall": overall_params})
    return params_dict


@dataclass
class GradioArgs:
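    """Demo configuration defaults; `seed` and `ratio` are populated in `__post_init__` when left as None."""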
    ckpt: str = "./mask/ff.pt"
    device: str = "cuda:0"
    seed: Optional[list] = None
    prompt: Optional[str] = None
    mix_precision: str = "bf16"
    num_intervention_steps: int = 50
    model: str = "sdxl"
    binary: bool = False
    masking: str = "binary"
    scope: str = "global"
    ratio: Optional[list] = None
    width: Optional[int] = None
    height: Optional[int] = None
    epsilon: float = 0.0
    lambda_threshold: float = 0.001

    def __post_init__(self):
        if self.seed is None:
            self.seed = [44]
        if self.ratio is None:
            self.ratio = [0.68, 0.88]


def prune_model(pipe, hookers):
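    """Remove pruned attention heads and FFN channels in place.

    For every hooked module, look up its learned mask (lambda), shrink the
    linear layers accordingly, and swap the pruned module back into its parent.
    """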
    # remove parameters in attention blocks
    cross_attn_hooker = hookers[0]
    for name in tqdm(cross_attn_hooker.hook_dict.keys(), desc="Pruning attention layers"):
        if getattr(pipe, "unet", None):
            module = pipe.unet.get_submodule(name)
        else:
            module = pipe.transformer.get_submodule(name)
        lamb = cross_attn_hooker.lambs[cross_attn_hooker.lambs_module_names.index(name)]
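        # one lambda entry gates one attention head, so the shapes must match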
        assert module.heads == lamb.shape[0]
        module = linear_layer_pruning(module, lamb)

        parent_module_name, child_name = name.rsplit(".", 1)
        if getattr(pipe, "unet", None):
            parent_module = pipe.unet.get_submodule(parent_module_name)
        else:
            parent_module = pipe.transformer.get_submodule(parent_module_name)
        setattr(parent_module, child_name, module)

    # remove parameters in ffn blocks
    ffn_hook = hookers[1]
    for name in tqdm(ffn_hook.hook_dict.keys(), desc="Pruning FFN linear layers"):
        if getattr(pipe, "unet", None):
            module = pipe.unet.get_submodule(name)
        else:
            module = pipe.transformer.get_submodule(name)
        lamb = ffn_hook.lambs[ffn_hook.lambs_module_names.index(name)]
        module = ffn_linear_layer_pruning(module, lamb)

        parent_module_name, child_name = name.rsplit(".", 1)
        if getattr(pipe, "unet", None):
            parent_module = pipe.unet.get_submodule(parent_module_name)
        else:
            parent_module = pipe.transformer.get_submodule(parent_module_name)
        setattr(parent_module, child_name, module)

    cross_attn_hooker.clear_hooks()
    ffn_hook.clear_hooks()
    return pipe


def binary_mask_eval(args):
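    """Build the pruned pipeline and reload a pristine copy for comparison.

    Returns (original_pipe, pruned_pipe).
    """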
    # load the SDXL base model in the requested precision
    torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype
    ).to(args.device)

    device = args.device
    mask_pipe, hookers = create_pipeline(
        pipe,
        args.model,
        device,
        torch_dtype,
        args.ckpt,
        binary=args.binary,
        lambda_threshold=args.lambda_threshold,
        epsilon=args.epsilon,
        masking=args.masking,
        return_hooker=True,
        scope=args.scope,
        ratio=args.ratio,
    )

    # Print mask sparsity info
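    # The lambda threshold only matters for soft (non-binary) masks; when a
    # pruning scope (global/local ratio) is set, the ratio determines sparsity
    # instead, so the threshold is dropped in both cases.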
    threshold = None if args.binary else args.lambda_threshold
    threshold = None if args.scope is not None else threshold
    name = ["ff", "attn"]
    for n, hooker in zip(name, hookers):
        total_num_heads, num_activate_heads, mask_sparsity = calculate_mask_sparsity(hooker, threshold)
        print(f"model: {args.model}, {n} masking: {args.masking}")
        print(
            f"total num heads: {total_num_heads},"
            + f"num activate heads: {num_activate_heads}, mask sparsity: {mask_sparsity}"
        )

    # Prune the model
    pruned_pipe = prune_model(mask_pipe, hookers)

    # reload the original model
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype
    ).to(args.device)

    # get model param summary
    print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
    print(f"pruned model param: {get_model_param_summary(pruned_pipe.unet)['overall']}")
    print("prune complete")
    return pipe, pruned_pipe


def generate_images(prompt, seed, steps, pipe, pruned_pipe):
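    """Generate one image from each pipeline using the same prompt and seed."""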
    # seed identical generators on each pipeline's device for a fair comparison
    generator = torch.Generator(pipe.device).manual_seed(seed)
    original_image = pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    generator = torch.Generator(pruned_pipe.device).manual_seed(seed)
    ecodiff_image = pruned_pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    return original_image, ecodiff_image


def on_prune_click(prompt, seed, steps):
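    """Gradio callback: initialize the original and pruned pipelines and update the status label."""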
    args = GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
    pipe, pruned_pipe = binary_mask_eval(args)
    return pipe, pruned_pipe, [("Model Initialized", "green")]


def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
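    """Gradio callback: forward the UI inputs to generate_images."""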
    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
    return original_image, ecodiff_image


def create_demo():
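    """Assemble the Gradio Blocks UI: model initialization, prompt controls, and side-by-side outputs."""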
    with gr.Blocks() as demo:
        gr.Markdown("# Text-to-Image Generation with EcoDiff Pruned Model")
        with gr.Row():
            gr.Markdown(
                """
                # 🚧 Under Construction 🚧
                This demo is currently being developed and may not be fully functional. More models and pruning ratios will be supported soon.
                The current pruned model checkpoint is not optimal and does not provide the best performance.
                
                **Note: Please first initialize the model before generating images.**
                """
            )
        with gr.Row():
            model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
            pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
            steps = gr.Slider(label="Number of Steps", minimum=1, maximum=100, value=50, step=1, scale=1)
            generate_btn = gr.Button("Generate Images")
        gr.Examples(
            examples=[
                "A clock tower floating in a sea of clouds",
                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                "An astronaut riding a green horse",
                "A delicious ceviche cheesecake slice",
            ],
            inputs=[prompt],
        )
        with gr.Row():
            original_output = gr.Image(label="Original Output")
            ecodiff_output = gr.Image(label="EcoDiff Output")

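        # gr.State keeps the heavyweight pipelines alive between callbacks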
        pipe_state = gr.State(None)
        pruned_pipe_state = gr.State(None)
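        # submitting the prompt textbox (Enter) triggers generation, same as the button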
        prompt.submit(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
        prune_btn.click(
            fn=on_prune_click,
            inputs=[prompt, seed, steps],
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )
        generate_btn.click(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)