"""Gradio demo comparing an original diffusion pipeline with its EcoDiff-pruned variant."""

from dataclasses import dataclass
from typing import Optional

# NOTE(review): the relative order of the third-party imports below is kept
# exactly as in the original file (`spaces` before `torch`); presumably the
# `spaces` package wants to be imported before anything touches CUDA -- confirm
# before letting an import sorter reorder them.
import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from diffusers import StableDiffusionXLPipeline, FluxPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hub repository that hosts the pickled pruned denoisers.
PRUNED_REPO_ID = "zhangyang-0123/EcoDiffPrunedModels"

# model key -> (pipeline class, base checkpoint, pipeline attribute that is
# swapped for the pruned module, file inside PRUNED_REPO_ID).
_MODEL_SPECS = {
    "sdxl": (
        StableDiffusionXLPipeline,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "unet",
        "model/sdxl/sdxl.pkl",
    ),
    "flux": (
        FluxPipeline,
        "black-forest-labs/FLUX.1-schnell",
        "transformer",
        "model/flux/flux.pkl",
    ),
}


@dataclass
class GradioArgs:
    """Bundle of run options assembled from the UI inputs.

    ``seed`` falls back to ``[44]`` when it is not supplied.
    """

    seed: Optional[list] = None
    prompt: Optional[str] = None
    mix_precision: str = "bf16"
    num_intervention_steps: int = 50
    model: str = "sdxl"
    binary: bool = False
    masking: str = "binary"
    scope: str = "global"
    ratio: Optional[list] = None
    width: Optional[int] = None
    height: Optional[int] = None
    epsilon: float = 0.0
    lambda_threshold: float = 0.001

    def __post_init__(self):
        if self.seed is None:
            self.seed = [44]


def binary_mask_eval(args, model):
    """Load the original pipeline and its pruned counterpart on the CPU.

    Args:
        args: Run options. Not read by this function; kept so existing
            callers keep working.
        model: Model name, case-insensitive. Supported: ``"sdxl"``, ``"flux"``.

    Returns:
        A ``(pipe, pruned_pipe)`` tuple: the unmodified pipeline and a second
        pipeline whose denoiser was replaced by the downloaded pruned module.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    model_key = model.lower()
    if model_key not in _MODEL_SPECS:
        supported = ", ".join(sorted(_MODEL_SPECS))
        raise ValueError(f"Unsupported model {model!r}; expected one of: {supported}")
    pipeline_cls, base_checkpoint, denoiser_attr, pruned_file = _MODEL_SPECS[model_key]

    # Build the pruned pipeline first: start from the base checkpoint, then
    # swap in the pruned denoiser.
    pruned_pipe = pipeline_cls.from_pretrained(base_checkpoint, torch_dtype=torch.bfloat16).to("cpu")
    pruned_path = hf_hub_download(PRUNED_REPO_ID, pruned_file)
    # SECURITY NOTE: the file is a full pickled module, so loading it executes
    # arbitrary pickle code -- only acceptable because the repository is
    # trusted. `weights_only=False` is passed explicitly because PyTorch >= 2.6
    # defaults to `weights_only=True`, which rejects pickled modules.
    pruned_module = torch.load(pruned_path, map_location="cpu", weights_only=False)
    setattr(pruned_pipe, denoiser_attr, pruned_module)

    # Load a second, untouched copy to serve as the comparison baseline.
    pipe = pipeline_cls.from_pretrained(base_checkpoint, torch_dtype=torch.bfloat16).to("cpu")

    print("prune complete")
    return pipe, pruned_pipe


def _render(pipeline, prompt, seed, steps, run_device):
    """Run ``pipeline`` once with a freshly seeded generator and return the first image."""
    generator = torch.Generator(run_device).manual_seed(seed)
    return pipeline(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]


@spaces.GPU
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    """Generate one image per pipeline from the same prompt, seed and step count.

    Args:
        prompt: Text prompt.
        seed: Seed for the random generator; coerced to ``int``.
        steps: Number of inference steps; coerced to ``int``.
        pipe: The original pipeline.
        pruned_pipe: The pruned pipeline.

    Returns:
        ``(original_image, ecodiff_image)``.
    """
    # Resolve the device at call time rather than hard-coding "cuda", so the
    # function also works (slowly) on machines without CUDA.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    seed = int(seed)
    steps = int(steps)

    pipe.to(run_device)
    pruned_pipe.to(run_device)

    # Each run gets its own generator seeded identically so that both
    # pipelines start from the same random state.
    original_image = _render(pipe, prompt, seed, steps, run_device)
    ecodiff_image = _render(pruned_pipe, prompt, seed, steps, run_device)
    return original_image, ecodiff_image


def on_prune_click(prompt, seed, steps, model):
    """Handle the initialize button: load both pipelines and report status.

    Returns:
        ``(pipe, pruned_pipe, status)`` where ``status`` is the value shown in
        the highlighted status label.
    """
    args = GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
    pipe, pruned_pipe = binary_mask_eval(args, model)
    return pipe, pruned_pipe, [("Model Initialized", "green")]


def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
    """Handle a generation request from the UI.

    Raises:
        gr.Error: If the pipelines have not been initialized yet, or if the
            seed field is empty. Without these guards the request would fail
            with an opaque AttributeError/TypeError.
    """
    if pipe is None or pruned_pipe is None:
        raise gr.Error("Please initialize the model before generating images.")
    if seed is None:
        raise gr.Error("Please provide a seed.")
    original_image, ecodiff_image = generate_images(prompt, seed, steps, pipe, pruned_pipe)
    return original_image, ecodiff_image


header = """
# 🌱 Text-to-Image Generation with EcoDiff Pruned Models
"""

header_2 = """
"""

header_3 = """
"""

_EXAMPLE_PROMPTS = [
    "A clock tower floating in a sea of clouds",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
    "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
]


def create_demo():
    """Build and return the Gradio Blocks UI (the app is not launched here)."""
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(header)
        with gr.Row():
            gr.HTML(header_2)
        with gr.Row():
            gr.HTML(header_3)
        with gr.Row():
            gr.Markdown(
                "**Note: Please first initialize the model before generating images. "
                "This may take a while to fully load.**"
            )

        with gr.Row():
            model_choice = gr.Radio(choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2)
            pruning_ratio = gr.Text("20% Pruning Ratio for SDXL, FLUX", label="Pruning Ratio", scale=2)
            status_label = gr.HighlightedText(
                label="Model Status",
                value=[("Model Not Initialized", "red")],
                scale=1,
            )
            prune_btn = gr.Button("Initialize Original and Pruned Models", variant="primary", scale=1)

        with gr.Row():
            gr.Markdown(
                "**Generate images with the original model and the pruned model. "
                "May take up to 1 minute due to dynamic allocation of GPU.**\n"
                "**Note: we prune on step-distilled FLUX, you should use step 5 "
                "(instead of 50) for FLUX generation.**"
            )

        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                value="A clock tower floating in a sea of clouds",
                scale=3,
            )
            seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
            steps = gr.Slider(
                label="Number of Steps",
                minimum=1,
                maximum=100,
                value=50,
                step=1,
                scale=1,
            )

        # NOTE(review): the original nesting of the button and examples was
        # lost with the file's formatting; they are placed below the input
        # row here -- confirm against the intended layout.
        generate_btn = gr.Button("Generate Images")
        gr.Examples(examples=_EXAMPLE_PROMPTS, inputs=[prompt])

        with gr.Row():
            original_output = gr.Image(label="Original Output")
            ecodiff_output = gr.Image(label="EcoDiff Output")

        # Per-session holders for the loaded pipelines; None until the user
        # initializes the models.
        pipe_state = gr.State(None)
        pruned_pipe_state = gr.State(None)

        generate_inputs = [prompt, seed, steps, pipe_state, pruned_pipe_state]
        generate_outputs = [original_output, ecodiff_output]

        # Pressing Enter in the prompt box and clicking the button both
        # trigger generation.
        prompt.submit(
            fn=on_generate_click,
            inputs=generate_inputs,
            outputs=generate_outputs,
        )
        prune_btn.click(
            fn=on_prune_click,
            inputs=[prompt, seed, steps, model_choice],
            outputs=[pipe_state, pruned_pipe_state, status_label],
        )
        generate_btn.click(
            fn=on_generate_click,
            inputs=generate_inputs,
            outputs=generate_outputs,
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)