# EcoDiff / app.py
# Author: zhangyang-0123
# Commit: 82d824b ("add ecodiff demo")
# (Hugging Face file-viewer residue: raw | history | blame | 8.35 kB)
import gradio as gr
from dataclasses import dataclass
import torch
from tqdm import tqdm
from src.utils import (
create_pipeline,
calculate_mask_sparsity,
ffn_linear_layer_pruning,
linear_layer_pruning,
)
from diffusers import StableDiffusionXLPipeline
def get_model_param_summary(model, verbose=False):
    """Count the parameters of ``model``, per named tensor and in total.

    Args:
        model: a ``torch.nn.Module`` whose ``named_parameters()`` are counted.
        verbose: if true, print the storage size (in MiB) of every parameter.

    Returns:
        dict mapping each parameter name to its element count, plus an
        ``"overall"`` key holding the sum over all parameters.
    """
    params_dict = {}
    for name, params in model.named_parameters():
        num_params = params.numel()
        if verbose:
            # Element count times bytes-per-element is the tensor's storage size;
            # the previous message interpolated the whole tensor instead.
            mib = num_params * params.element_size() / 2**20
            print(f"GPU Memory Requirement for {name}: {mib:.2f} MiB")
        params_dict[name] = num_params
    params_dict["overall"] = sum(params_dict.values())
    return params_dict
@dataclass
class GradioArgs:
    """Settings bundle handed to ``binary_mask_eval`` by the Gradio handlers.

    ``seed`` and ``ratio`` are list-valued; their defaults are filled in
    ``__post_init__`` so that every instance gets its own fresh list.
    """

    # Mask checkpoint path passed to create_pipeline.
    ckpt: str = "./mask/ff.pt"
    # Target device for the loaded pipelines.
    device: str = "cuda:0"
    # List of seeds; None means [44].
    seed: list = None
    prompt: str = None
    # "bf16" selects torch.bfloat16; any other value selects torch.float32.
    mix_precision: str = "bf16"
    # Set from the UI step slider by on_prune_click.
    num_intervention_steps: int = 50
    model: str = "sdxl"
    # When true, the sparsity report is computed without a lambda threshold.
    binary: bool = False
    masking: str = "binary"
    scope: str = "global"
    # Pruning ratios forwarded to create_pipeline; None means [0.68, 0.88].
    ratio: list = None
    width: int = None
    height: int = None
    epsilon: float = 0.0
    lambda_threshold: float = 0.001

    def __post_init__(self):
        # dataclasses reject a bare list literal as a default, so the
        # list-valued defaults are materialized here, per instance.
        self.seed = [44] if self.seed is None else self.seed
        self.ratio = [0.68, 0.88] if self.ratio is None else self.ratio
def _pipe_backbone(pipe):
    """Return the pipeline's denoiser: its ``unet`` if set, else its ``transformer``."""
    unet = getattr(pipe, "unet", None)
    return unet if unet else pipe.transformer


def _replace_submodule(root, name, new_module):
    """Install ``new_module`` at dotted path ``name`` inside ``root``."""
    parent_module_name, child_name = name.rsplit(".", 1)
    setattr(root.get_submodule(parent_module_name), child_name, new_module)


def _prune_hooked_modules(root, hooker, prune_fn, desc, check_heads=False):
    """Prune every module hooked by ``hooker`` with ``prune_fn`` and swap it into ``root``."""
    for name in tqdm(hooker.hook_dict.keys(), desc=desc):
        module = root.get_submodule(name)
        lamb = hooker.lambs[hooker.lambs_module_names.index(name)]
        if check_heads:
            # Attention masks carry one entry per head.
            assert module.heads == lamb.shape[0]
        _replace_submodule(root, name, prune_fn(module, lamb))


def prune_model(pipe, hookers):
    """Physically remove masked-out parameters from ``pipe``'s denoiser.

    Args:
        pipe: a diffusion pipeline exposing either ``unet`` or ``transformer``.
        hookers: sequence where ``hookers[0]`` is the cross-attention mask hooker
            and ``hookers[1]`` the FFN mask hooker. Each must expose
            ``hook_dict`` (keyed by submodule path), ``lambs`` /
            ``lambs_module_names`` (mask values aligned by name) and
            ``clear_hooks()``.

    Returns:
        The same ``pipe`` object. Its hooked submodules are replaced in place by
        the outputs of ``linear_layer_pruning`` (attention) and
        ``ffn_linear_layer_pruning`` (FFN), and both hookers' hooks are cleared.
    """
    root = _pipe_backbone(pipe)

    # remove parameters in attention blocks
    cross_attn_hooker = hookers[0]
    _prune_hooked_modules(
        root,
        cross_attn_hooker,
        linear_layer_pruning,
        "Pruning attention layers",
        check_heads=True,
    )

    # remove parameters in ffn blocks
    ffn_hook = hookers[1]
    _prune_hooked_modules(root, ffn_hook, ffn_linear_layer_pruning, "Pruning on FFN linear lazer")

    cross_attn_hooker.clear_hooks()
    ffn_hook.clear_hooks()
    return pipe
def binary_mask_eval(args):
    """Build the original and the EcoDiff-pruned SDXL pipelines.

    Loads SDXL and wraps it with the masks from ``args.ckpt`` via
    ``create_pipeline``. It prints the per-hooker mask sparsity, prunes the
    masked layers with ``prune_model``, then loads a fresh SDXL copy as the
    unpruned reference and prints both UNet parameter counts.

    Args:
        args: a ``GradioArgs``; reads ``device``, ``mix_precision``, ``model``,
            ``ckpt``, ``binary``, ``lambda_threshold``, ``epsilon``,
            ``masking``, ``scope`` and ``ratio``.

    Returns:
        ``(pipe, pruned_pipe)``: the freshly loaded original pipeline and the
        pruned pipeline.
    """
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    device = args.device
    # Load in the same dtype that create_pipeline is told to use, so that
    # mix_precision is honoured consistently.
    torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32

    # load sdxl model
    pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
    mask_pipe, hookers = create_pipeline(
        pipe,
        args.model,
        device,
        torch_dtype,
        args.ckpt,
        binary=args.binary,
        lambda_threshold=args.lambda_threshold,
        epsilon=args.epsilon,
        masking=args.masking,
        return_hooker=True,
        scope=args.scope,
        ratio=args.ratio,
    )

    # Print mask sparsity info. A lambda threshold is only applied for
    # non-binary masks when no scope is set.
    threshold = None if args.binary or args.scope is not None else args.lambda_threshold
    # Labels follow prune_model's convention: hookers[0] is the cross-attention
    # hooker, hookers[1] the FFN hooker.
    for n, hooker in zip(["attn", "ff"], hookers):
        total_num_heads, num_activate_heads, mask_sparsity = calculate_mask_sparsity(hooker, threshold)
        print(f"model: {args.model}, {n} masking: {args.masking}")
        print(
            f"total num heads: {total_num_heads}, "
            f"num activate heads: {num_activate_heads}, mask sparsity: {mask_sparsity}"
        )

    # Prune the model
    pruned_pipe = prune_model(mask_pipe, hookers)

    # Reload the original model as the unpruned reference.
    # NOTE(review): this assumes create_pipeline/prune_model mutate the first
    # `pipe` in place, which is why a second copy is loaded — confirm.
    pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)

    # get model param summary
    print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
    print(f"pruned model param: {get_model_param_summary(pruned_pipe.unet)['overall']}")
    print("prune complete")
    return pipe, pruned_pipe
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    """Render ``prompt`` with both pipelines from identical random seeds.

    Each pipeline gets its own ``torch.Generator`` on that pipeline's device,
    seeded with the same value, so the two outputs differ only because of
    pruning.

    Args:
        prompt: text prompt.
        seed: integer seed (a float such as Gradio's ``Number`` output is
            coerced with ``int``).
        steps: number of inference steps.
        pipe: the original pipeline.
        pruned_pipe: the pruned pipeline.

    Returns:
        ``(original_image, ecodiff_image)``: the first image from each pipeline.
    """
    seed = int(seed)
    # Create the generator on the pipeline's own device instead of assuming cuda:0.
    generator = torch.Generator(device=pipe.device).manual_seed(seed)
    original_image = pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    generator = torch.Generator(device=pruned_pipe.device).manual_seed(seed)
    ecodiff_image = pruned_pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    return original_image, ecodiff_image
def on_prune_click(prompt, seed, steps):
    """Gradio handler for the "Initialize Original and Pruned Models" button.

    Returns the two pipelines (stored in ``gr.State`` by the caller) and the
    value for the model-status ``HighlightedText``.
    """
    pipe, pruned_pipe = binary_mask_eval(
        GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
    )
    status = [("Model Initialized", "green")]
    return pipe, pruned_pipe, status
def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
    """Gradio handler for the generate button and for prompt submission.

    Returns:
        ``(original_image, ecodiff_image)`` from ``generate_images``.

    Raises:
        gr.Error: if either pipeline is still ``None``, i.e. the models have not
            been initialized yet. The ``gr.State`` holders start as ``None``.
    """
    if pipe is None or pruned_pipe is None:
        # gr.Error is shown to the user as a message instead of an opaque
        # "'NoneType' object is not callable" failure.
        raise gr.Error("Please initialize the models before generating images.")
    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
    return original_image, ecodiff_image
def create_demo():
    """Build and return the Gradio Blocks UI comparing original vs EcoDiff-pruned outputs.

    Component creation order inside ``gr.Blocks`` defines the on-screen layout.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Text-to-Image Generation with EcoDiff Pruned Model")
        with gr.Row():
            gr.Markdown(
                """
# 🚧 Under Construction 🚧
This demo is currently being developed and may not be fully functional. More models and pruning ratios will be supported soon.
The current pruned model checkpoint is not optimal and does not provide the best performance.
**Note: Please first initialize the model before generating images.**
"""
            )
        with gr.Row():
            # NOTE(review): model_choice and pruning_ratio are not passed to any
            # event handler below, so changing them has no effect yet.
            # NOTE(review): Gradio documents `scale` as an int; 1.2 may be
            # rejected or warned about depending on the Gradio version — verify.
            model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
            pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
            # (text, label) pairs; on_prune_click replaces this with a "green" entry.
            # The displayed colour per label depends on Gradio's color handling — confirm.
            status_label = gr.HighlightedText(label="Model Status", value=[("Model Not Initialized", "red")], scale=1)
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
            # precision=0 makes the Number component yield whole numbers for the seed.
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
            steps = gr.Slider(label="Number of Steps", minimum=1, maximum=100, value=50, step=1, scale=1)
        generate_btn = gr.Button("Generate Images")
        gr.Examples(
            examples=[
                "A clock tower floating in a sea of clouds",
                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                "An astronaut riding a green horse",
                "A delicious ceviche cheesecake slice",
            ],
            inputs=[prompt],
        )
        with gr.Row():
            original_output = gr.Image(label="Original Output")
            ecodiff_output = gr.Image(label="EcoDiff Output")
        # Per-session holders for the two pipelines; None until prune_btn is clicked.
        pipe_state = gr.State(None)
        pruned_pipe_state = gr.State(None)
        # Pressing Enter in the prompt box runs the same handler as generate_btn.
        prompt.submit(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
        # Loads both pipelines and stores them in the State holders above.
        prune_btn.click(
            fn=on_prune_click,
            inputs=[prompt, seed, steps],
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )
        generate_btn.click(
            fn=on_generate_click,
            inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
            outputs=[original_output, ecodiff_output],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it. share=True additionally
    # asks Gradio for a public share link alongside the local server.
    demo = create_demo()
    demo.launch(share=True)