"""Gradio demo comparing Stable Diffusion XL with its EcoDiff-pruned counterpart.

Two pipelines are loaded once at import time: the reference SDXL pipeline and a
copy whose UNet is replaced by the pruned model downloaded from the Hub. Each
request renders the same prompt/seed/step count through both pipelines so the
images can be compared side by side.
"""

import gradio as gr
import numpy as np
import spaces  # provides the @spaces.GPU decorator used on Hugging Face ZeroGPU
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"  # Replace with the model you would like to use
torch_dtype = torch.bfloat16

# Reference (unpruned) pipeline.
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Pruned pipeline: same checkpoint, with its denoiser swapped for the EcoDiff-pruned one.
pruned_pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# SDXL's denoiser is registered as the `unet` component (SDXL has no `transformer`),
# so the pruned model must be assigned there. Assigning it to `transformer` only
# created an unused attribute, and the pipeline kept running the original UNet.
# Assumes the pickle holds the full pruned UNet module. TODO: confirm against the
# EcoDiffPrunedModels repo.
# A pickled module cannot be loaded with torch>=2.6's default weights_only=True.
# SECURITY NOTE: weights_only=False unpickles arbitrary objects, so only point this
# at a repository you trust.
pruned_pipe.unet = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
    map_location="cpu",
    weights_only=False,
).to(torch_dtype)  # match the rest of the pipeline's dtype (no-op if already bfloat16)
pruned_pipe = pruned_pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


# The pipeline defaults are bound at definition time to the module-level pipelines,
# so Gradio can call this with only the three UI inputs (prompt, seed, steps).
@spaces.GPU
def generate_images(prompt, seed, steps, pipe=pipe, pruned_pipe=pruned_pipe):
    """Render ``prompt`` with both the reference and the pruned pipeline.

    Args:
        prompt: Text prompt given to both pipelines.
        seed: Seed value; converted to ``int``. Each pipeline gets a freshly
            seeded generator, so both runs draw the same random numbers.
        steps: Number of inference steps for both pipelines; converted to ``int``.
        pipe: Reference pipeline (defaults to the module-level ``pipe``).
        pruned_pipe: Pruned pipeline (defaults to the module-level ``pruned_pipe``).

    Returns:
        ``(original_image, ecodiff_image)``: the first image produced by each pipeline.
    """
    # Gradio number/slider widgets may deliver floats; torch and diffusers expect ints.
    seed = int(seed)
    steps = int(steps)

    # Create the generator on the active device instead of hard-coding "cuda",
    # so the demo also works on CPU-only hosts.
    generator = torch.Generator(device).manual_seed(seed)
    original_image = pipe(
        prompt=prompt, generator=generator, num_inference_steps=steps
    ).images[0]

    # Re-seed so the pruned run starts from the same random state as the reference run.
    generator = torch.Generator(device).manual_seed(seed)
    ecodiff_image = pruned_pipe(
        prompt=prompt, generator=generator, num_inference_steps=steps
    ).images[0]

    return original_image, ecodiff_image


# Single source of truth for the example prompts shown in the UI.
examples = [
    "A clock tower floating in a sea of clouds",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
    "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

header = """
# 🌱 Text-to-Image Generation with EcoDiff Pruned SD-XL (20% Pruning Ratio)
# Under Construction!!!
arXiv HuggingFace GitHub
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(header)

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=50,
            step=1,
            scale=1,
        )

    generate_btn = gr.Button("Generate Images")
    gr.Examples(examples=examples, inputs=[prompt])

    with gr.Row():
        original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")

    # Event inputs must be Gradio components. The pipelines are not components;
    # they reach generate_images through its default arguments. Passing `pipe`
    # twice here would also have made the "EcoDiff" output come from the
    # unpruned model.
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_images,
        inputs=[prompt, seed, steps],
        outputs=[original_output, ecodiff_output],
    )

if __name__ == "__main__":
    demo.launch()