"""Gradio demo comparing the original SDXL pipeline with an EcoDiff-pruned one."""

from dataclasses import dataclass
from typing import List, Optional

import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline
from tqdm import tqdm

from src.utils import (
    calculate_mask_sparsity,
    create_pipeline,
    ffn_linear_layer_pruning,
    linear_layer_pruning,
)

# Hugging Face model id used for both the original and the pruned pipeline.
SDXL_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"


def get_model_param_summary(model, verbose=False):
    """Count the parameters of ``model``.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        verbose: When true, print the memory footprint of every parameter
            tensor in MiB (``numel * element_size``).

    Returns:
        dict mapping each parameter name to its element count, plus the key
        ``"overall"`` holding the sum over all parameters.
    """
    params_dict = {}
    overall_params = 0
    for name, params in model.named_parameters():
        num_params = params.numel()
        overall_params += num_params
        if verbose:
            # Report the actual size; the previous code interpolated the whole
            # tensor into the message instead of a number of MiB.
            size_mib = num_params * params.element_size() / 2**20
            print(f"GPU Memory Requirement for {name}: {size_mib:.2f} MiB")
        params_dict[name] = num_params
    params_dict["overall"] = overall_params
    return params_dict


@dataclass
class GradioArgs:
    """Argument bundle handed to ``binary_mask_eval`` by the demo callbacks.

    ``seed`` and ``ratio`` default to ``None`` and are replaced in
    ``__post_init__`` so that instances never share a mutable default list.
    """

    # Mask checkpoint path forwarded to ``create_pipeline``.
    ckpt: str = "./mask/ff.pt"
    # Device the pipelines are moved to.
    device: str = "cuda:0"
    seed: Optional[List[int]] = None
    prompt: Optional[str] = None
    # "bf16" selects torch.bfloat16; any other value selects torch.float32.
    mix_precision: str = "bf16"
    num_intervention_steps: int = 50
    model: str = "sdxl"
    binary: bool = False
    masking: str = "binary"
    scope: str = "global"
    ratio: Optional[List[float]] = None
    width: Optional[int] = None
    height: Optional[int] = None
    epsilon: float = 0.0
    lambda_threshold: float = 0.001

    def __post_init__(self):
        if self.seed is None:
            self.seed = [44]
        if self.ratio is None:
            self.ratio = [0.68, 0.88]


def _get_backbone(pipe):
    """Return ``pipe.unet`` when the pipeline has one, else ``pipe.transformer``."""
    unet = getattr(pipe, "unet", None)
    return unet if unet is not None else pipe.transformer


def _swap_submodule(backbone, name, new_module):
    """Replace the submodule at dotted path ``name`` of ``backbone`` with ``new_module``."""
    parent_name, child_name = name.rsplit(".", 1)
    setattr(backbone.get_submodule(parent_name), child_name, new_module)


def prune_model(pipe, hookers):
    """Replace every hooked attention and FFN module with its pruned version.

    Args:
        pipe: Pipeline whose ``unet`` (or, failing that, ``transformer``)
            contains the hooked modules.
        hookers: Sequence where ``hookers[0]`` is the attention hooker and
            ``hookers[1]`` the FFN hooker. Each must expose ``hook_dict``,
            ``lambs``, ``lambs_module_names`` and ``clear_hooks()``.

    Returns:
        The same ``pipe`` object, modified in place, with the hooks cleared.
    """
    backbone = _get_backbone(pipe)

    # remove parameters in attention blocks
    cross_attn_hooker = hookers[0]
    for name in tqdm(cross_attn_hooker.hook_dict.keys(), desc="Pruning attention layers"):
        module = backbone.get_submodule(name)
        lamb = cross_attn_hooker.lambs[cross_attn_hooker.lambs_module_names.index(name)]
        # The mask must carry exactly one entry per attention head.
        assert module.heads == lamb.shape[0], (
            f"{name}: module has {module.heads} heads but mask has {lamb.shape[0]} entries"
        )
        _swap_submodule(backbone, name, linear_layer_pruning(module, lamb))

    # remove parameters in ffn blocks
    ffn_hook = hookers[1]
    for name in tqdm(ffn_hook.hook_dict.keys(), desc="Pruning on FFN linear layer"):
        module = backbone.get_submodule(name)
        lamb = ffn_hook.lambs[ffn_hook.lambs_module_names.index(name)]
        _swap_submodule(backbone, name, ffn_linear_layer_pruning(module, lamb))

    cross_attn_hooker.clear_hooks()
    ffn_hook.clear_hooks()
    return pipe


def _load_sdxl(device, torch_dtype):
    """Load the SDXL base pipeline in ``torch_dtype`` and move it to ``device``."""
    return StableDiffusionXLPipeline.from_pretrained(SDXL_MODEL_ID, torch_dtype=torch_dtype).to(device)


def binary_mask_eval(args):
    """Build the pruned SDXL pipeline and a fresh unpruned one for comparison.

    Args:
        args: ``GradioArgs``-like object; the fields read here are ``device``,
            ``mix_precision``, ``model``, ``ckpt``, ``binary``,
            ``lambda_threshold``, ``epsilon``, ``masking``, ``scope`` and
            ``ratio``.

    Returns:
        ``(pipe, pruned_pipe)``: the freshly loaded original pipeline and the
        pruned pipeline.
    """
    device = args.device
    torch_dtype = torch.bfloat16 if args.mix_precision == "bf16" else torch.float32

    # load sdxl model (in the dtype selected by mix_precision, so the loaded
    # weights agree with the dtype handed to create_pipeline)
    pipe = _load_sdxl(device, torch_dtype)
    mask_pipe, hookers = create_pipeline(
        pipe,
        args.model,
        device,
        torch_dtype,
        args.ckpt,
        binary=args.binary,
        lambda_threshold=args.lambda_threshold,
        epsilon=args.epsilon,
        masking=args.masking,
        return_hooker=True,
        scope=args.scope,
        ratio=args.ratio,
    )

    # Print mask sparsity info. The threshold is only passed when the mask is
    # neither binary nor governed by a scope.
    use_threshold = not args.binary and args.scope is None
    threshold = args.lambda_threshold if use_threshold else None
    # Label order follows prune_model: hookers[0] is attention, hookers[1] is FFN.
    names = ["attn", "ff"]
    for n, hooker in zip(names, hookers):
        total_num_heads, num_activate_heads, mask_sparsity = calculate_mask_sparsity(hooker, threshold)
        print(f"model: {args.model}, {n} masking: {args.masking}")
        print(
            f"total num heads: {total_num_heads},"
            f"num activate heads: {num_activate_heads}, mask sparsity: {mask_sparsity}"
        )

    # Prune the model
    pruned_pipe = prune_model(mask_pipe, hookers)

    # reload the original model
    pipe = _load_sdxl(device, torch_dtype)

    # get model param summary
    print(f"original model param: {get_model_param_summary(pipe.unet)['overall']}")
    print(f"pruned model param: {get_model_param_summary(pruned_pipe.unet)['overall']}")
    print("prune complete")
    return pipe, pruned_pipe


def generate_images(prompt, seed, steps, pipe, pruned_pipe, device="cuda:0"):
    """Render ``prompt`` with both pipelines using identically seeded generators.

    Args:
        prompt: Text prompt.
        seed: Seed for the random generator; cast to ``int``.
        steps: Number of inference steps; cast to ``int``.
        pipe: Original pipeline (callable returning an object with ``.images``).
        pruned_pipe: Pruned pipeline with the same call interface.
        device: Device of the ``torch.Generator``; defaults to ``"cuda:0"``.

    Returns:
        ``(original_image, ecodiff_image)``.
    """
    seed = int(seed)
    steps = int(steps)

    def _render(pipeline):
        # A fresh generator per call so both pipelines start from the same noise.
        generator = torch.Generator(device).manual_seed(seed)
        return pipeline(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]

    original_image = _render(pipe)
    ecodiff_image = _render(pruned_pipe)
    return original_image, ecodiff_image


def on_prune_click(prompt, seed, steps):
    """Gradio callback: load both pipelines and report the new status.

    Returns:
        ``(pipe, pruned_pipe, status)`` where ``status`` is the value for the
        ``HighlightedText`` status label.
    """
    args = GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
    pipe, pruned_pipe = binary_mask_eval(args)
    return pipe, pruned_pipe, [("Model Initialized", "green")]


def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
    """Gradio callback: generate the original and the pruned image.

    Raises:
        gr.Error: If the pipelines have not been initialized yet, so the user
            gets a readable message instead of a ``TypeError`` on ``None``.
    """
    if pipe is None or pruned_pipe is None:
        raise gr.Error("Please initialize the models before generating images.")
    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
    return original_image, ecodiff_image


def create_demo():
    """Assemble the Gradio Blocks UI and wire up its callbacks.

    Returns:
        The ``gr.Blocks`` demo, not yet launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Text-to-Image Generation with EcoDiff Pruned Model")
        with gr.Row():
            gr.Markdown(
                """
                # 🚧 Under Construction 🚧
                This demo is currently being developed and may not be fully functional. More models and pruning ratios will be supported soon.
                The current pruned model checkpoint is not optimal and does not provide the best performance.

                **Note: Please first initialize the model before generating images.**
                """
            )
        with gr.Row():
            # The two dropdowns are display-only: no callback reads their value.
            model_choice = gr.Dropdown(choices=["SDXL"], value="SDXL", label="Model", scale=1.2)
            pruning_ratio = gr.Dropdown(choices=["20%"], value="20%", label="Pruning Ratio", scale=1.2)
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)
            status_label = gr.HighlightedText(
                label="Model Status", value=[("Model Not Initialized", "red")], scale=1
            )
        with gr.Row():
            prompt = gr.Textbox(label="Prompt", value="A clock tower floating in a sea of clouds", scale=3)
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
            steps = gr.Slider(label="Number of Steps", minimum=1, maximum=100, value=50, step=1, scale=1)
            generate_btn = gr.Button("Generate Images")
        gr.Examples(
            examples=[
                "A clock tower floating in a sea of clouds",
                "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                "An astronaut riding a green horse",
                "A delicious ceviche cheesecake slice",
            ],
            inputs=[prompt],
        )
        with gr.Row():
            original_output = gr.Image(label="Original Output")
            ecodiff_output = gr.Image(label="EcoDiff Output")

        # Per-session holders for the two pipelines produced by on_prune_click.
        pipe_state = gr.State(None)
        pruned_pipe_state = gr.State(None)

        generate_inputs = [prompt, seed, steps, pipe_state, pruned_pipe_state]
        generate_outputs = [original_output, ecodiff_output]

        # Pressing Enter in the prompt box behaves like the generate button.
        prompt.submit(fn=on_generate_click, inputs=generate_inputs, outputs=generate_outputs)
        prune_btn.click(
            fn=on_prune_click,
            inputs=[prompt, seed, steps],
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )
        generate_btn.click(fn=on_generate_click, inputs=generate_inputs, outputs=generate_outputs)
    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)