# app.py: Gradio Space demo (running on ZeroGPU) for EcoDiff-pruned Stable Diffusion XL
from dataclasses import dataclass

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from tqdm import tqdm

from src.utils import (
    calculate_mask_sparsity,
    create_pipeline,
    ffn_linear_layer_pruning,
    linear_layer_pruning,
)
def get_model_param_summary(model, verbose=False):
    """Return a dict mapping each parameter name to its element count, plus an "overall" total."""
    params_dict = dict()
    overall_params = 0
    for name, params in model.named_parameters():
        num_params = params.numel()
        overall_params += num_params
        if verbose:
            print(f"Parameter count for {name}: {num_params}")
        params_dict.update({name: num_params})
    params_dict.update({"overall": overall_params})
    return params_dict
@dataclass
class GradioArgs:
    ckpt: str = "./mask/ff.pt"
    device: str = "cuda:0"
    seed: list = None
    prompt: str = None
    mix_precision: str = "bf16"
    num_intervention_steps: int = 50
    model: str = "sdxl"
    binary: bool = False
    masking: str = "binary"
    scope: str = "global"
    ratio: list = None
    width: int = None
    height: int = None
    epsilon: float = 0.0
    lambda_threshold: float = 0.001

    def __post_init__(self):
        # Mutable defaults are not allowed as dataclass field values, so fill them in here.
        if self.seed is None:
            self.seed = [44]
        if self.ratio is None:
            self.ratio = [0.68, 0.88]
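# A minimal usage sketch (commented out so the module stays import-safe): the
# Gradio callbacks below construct args this way, relying on __post_init__ for
# the list defaults. The name `example_args` is illustrative only.
# example_args = GradioArgs(prompt="A clock tower floating in a sea of clouds", seed=[44])
# assert example_args.ratio == [0.68, 0.88]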
def prune_model(pipe, hookers):
    # Remove parameters in attention blocks
    cross_attn_hooker = hookers[0]
    for name in tqdm(cross_attn_hooker.hook_dict.keys(), desc="Pruning attention layers"):
        if getattr(pipe, "unet", None):
            module = pipe.unet.get_submodule(name)
        else:
            module = pipe.transformer.get_submodule(name)
        lamb = cross_attn_hooker.lambs[cross_attn_hooker.lambs_module_names.index(name)]
        assert module.heads == lamb.shape[0]
        module = linear_layer_pruning(module, lamb)

        # Swap the pruned module back into its parent
        parent_module_name, child_name = name.rsplit(".", 1)
        if getattr(pipe, "unet", None):
            parent_module = pipe.unet.get_submodule(parent_module_name)
        else:
            parent_module = pipe.transformer.get_submodule(parent_module_name)
        setattr(parent_module, child_name, module)

    # Remove parameters in FFN blocks
    ffn_hook = hookers[1]
    for name in tqdm(ffn_hook.hook_dict.keys(), desc="Pruning FFN linear layers"):
        if getattr(pipe, "unet", None):
            module = pipe.unet.get_submodule(name)
        else:
            module = pipe.transformer.get_submodule(name)
        lamb = ffn_hook.lambs[ffn_hook.lambs_module_names.index(name)]
        module = ffn_linear_layer_pruning(module, lamb)

        parent_module_name, child_name = name.rsplit(".", 1)
        if getattr(pipe, "unet", None):
            parent_module = pipe.unet.get_submodule(parent_module_name)
        else:
            parent_module = pipe.transformer.get_submodule(parent_module_name)
        setattr(parent_module, child_name, module)

    cross_attn_hooker.clear_hooks()
    ffn_hook.clear_hooks()
    return pipe
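# For intuition, a sketch of the kind of transformation head pruning applies.
# This is NOT the src.utils implementation, only an assumed illustration: given
# a binary per-head mask `lamb` of shape (heads,), drop the out-projection
# columns fed by masked-out heads.
def _head_pruning_sketch(out_proj_weight: torch.Tensor, lamb: torch.Tensor, head_dim: int) -> torch.Tensor:
    keep = torch.repeat_interleave(lamb.bool(), head_dim)  # expand head mask to a per-channel mask
    return out_proj_weight[:, keep]  # keep only the columns of surviving heads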
def binary_mask_eval(args):
    # Load the SDXL base model
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
    ).to(args.device)
    device = args.device
    torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32
    mask_pipe, hookers = create_pipeline(
        pipe,
        args.model,
        device,
        torch_dtype,
        args.ckpt,
        binary=args.binary,
        lambda_threshold=args.lambda_threshold,
        epsilon=args.epsilon,
        masking=args.masking,
        return_hooker=True,
        scope=args.scope,
        ratio=args.ratio,
    )

    # Print mask sparsity info; ratio-based scoping overrides the lambda threshold
    threshold = None if args.binary else args.lambda_threshold
    threshold = None if args.scope is not None else threshold
    names = ["ff", "attn"]
    for n, hooker in zip(names, hookers):
        total_num_heads, num_activate_heads, mask_sparsity = calculate_mask_sparsity(hooker, threshold)
        print(f"model: {args.model}, {n} masking: {args.masking}")
        print(
            f"total num heads: {total_num_heads}, "
            f"num active heads: {num_activate_heads}, mask sparsity: {mask_sparsity}"
        )

    # Prune the model
    pruned_pipe = prune_model(mask_pipe, hookers)

    # Reload the original (unpruned) model for side-by-side comparison
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
    ).to(args.device)

    # Report parameter counts before and after pruning
    print(f"original model params: {get_model_param_summary(pipe.unet)['overall']}")
    print(f"pruned model params: {get_model_param_summary(pruned_pipe.unet)['overall']}")
    print("prune complete")
    return pipe, pruned_pipe
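# Sketch (an assumption, not part of the original script): the two "overall"
# counts printed above can be turned into a headline reduction figure, which
# should roughly match the pruning ratio advertised in the UI.
# def _param_reduction(original: int, pruned: int) -> float:
#     return 1.0 - pruned / original  # e.g. ~0.20 for a 20% pruning ratio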
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    # Run both models with identically seeded generators so they start from the same noise
    generator = torch.Generator("cuda:0").manual_seed(seed)
    original_image = pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    generator = torch.Generator("cuda:0").manual_seed(seed)
    ecodiff_image = pruned_pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    return original_image, ecodiff_image
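# Usage sketch (assumes both pipelines were already initialized via binary_mask_eval):
# orig_img, pruned_img = generate_images("An astronaut riding a green horse", 44, 50, pipe, pruned_pipe)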
def on_prune_click(prompt, seed, steps):
    args = GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
    pipe, pruned_pipe = binary_mask_eval(args)
    return pipe, pruned_pipe, [("Model Initialized", "green")]


def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
    return original_image, ecodiff_image
def create_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# Text-to-Image Generation with EcoDiff Pruned Model")
        with gr.Row():
            gr.Markdown(
                """
                # 🚧 Under Construction 🚧
                This demo is still in development and may not be fully functional. Support for more models and pruning ratios is coming soon.
                The current pruned checkpoint is not yet optimal and does not reflect the best achievable quality.

                **Note: Please initialize the model before generating images.**
                """
            )
        with gr.Row():
            model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
            pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
            steps = gr.Slider(label="Number of Steps", minimum=1, maximum=100, value=50, step=1, scale=1)
            generate_btn = gr.Button("Generate Images")
        gr.Examples(
            examples=[
                "A clock tower floating in a sea of clouds",
                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                "An astronaut riding a green horse",
                "A delicious ceviche cheesecake slice",
            ],
            inputs=[prompt],
        )
        with gr.Row():
            original_output = gr.Image(label="Original Output")
            ecodiff_output = gr.Image(label="EcoDiff Output")

        # Pipeline objects persist across callbacks via session state
        pipe_state = gr.State(None)
        pruned_pipe_state = gr.State(None)

        prompt.submit(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
        prune_btn.click(
            fn=on_prune_click,
            inputs=[prompt, seed, steps],
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )
        generate_btn.click(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
    return demo
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)