Spaces:
Sleeping
Sleeping
File size: 3,759 Bytes
79fcfbd 78ea680 79fcfbd ee69d30 79fcfbd 78ea680 79fcfbd 78ea680 79fcfbd ee69d30 79fcfbd ee69d30 79fcfbd ee69d30 79fcfbd ee69d30 79fcfbd ee69d30 79fcfbd ee69d30 79fcfbd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
import spaces # [uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch
# Run on the GPU when torch reports one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"  # Replace to the model you would like to use

# Half precision is only selected together with CUDA; CPU runs stay in float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Reference (unpruned) pipeline.
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Pruned pipeline: same base checkpoint, with the denoising network replaced
# by the EcoDiff-pruned one downloaded from the Hub.
pruned_pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# SD-XL pipelines denoise with the `unet` component. Assigning the pruned model
# to `transformer` would only create an unused attribute, leaving this pipeline
# identical to `pipe`.
# Presumably the pickle holds the pruned SD-XL UNet module; verify against the
# model repository.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# this file from a trusted repository. Recent torch releases may also require
# an explicit `weights_only=False` to load a fully pickled module; confirm
# against the installed torch version.
pruned_pipe.unet = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
    map_location="cpu",
)
pruned_pipe = pruned_pipe.to(device)

# Largest seed value representable as a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
@spaces.GPU  # ZeroGPU decorator from the `spaces` package
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    """Render one prompt with both pipelines so the results can be compared.

    Each pipeline gets its own generator seeded with the same value, so both
    runs start from the same random state.

    Args:
        prompt: Text prompt forwarded to both pipelines.
        seed: Seed for the random generator; converted to ``int``.
        steps: Number of inference steps; converted to ``int``.
        pipe: Pipeline used for the "original" image.
        pruned_pipe: Pipeline used for the "EcoDiff" image.

    Returns:
        A tuple ``(original_image, ecodiff_image)`` holding the first image
        produced by each pipeline.
    """
    # UI widgets may hand over floats; manual_seed and the step count need ints.
    seed = int(seed)
    steps = int(steps)

    def _render(pipeline):
        # Create the generator on the module-level `device` instead of a
        # hard-coded "cuda", so the CPU fallback chosen at import time works.
        generator = torch.Generator(device).manual_seed(seed)
        result = pipeline(
            prompt=prompt,
            generator=generator,
            num_inference_steps=steps,
        )
        return result.images[0]

    original_image = _render(pipe)
    ecodiff_image = _render(pruned_pipe)
    return original_image, ecodiff_image
# Sample prompts.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Stylesheet handed to gr.Blocks; styles the element with id "col-container".
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""

# Markdown/HTML banner rendered at the top of the page: title plus badge links
# to the paper, the model repository and the source code.
# The paper badge label is spelled "arXiv" (it previously read "ariXv").
header = """
# 🌱 Text-to-Image Generation with EcoDiff Pruned SD-XL (20% Pruning Ratio)
# Under Construction!!!
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/YaNgZhAnG-V5/EcoDiff"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def _generate_for_ui(prompt, seed, steps):
    """Event handler: run the original and the pruned pipeline on one prompt.

    Event inputs have to be Gradio components, so the two module-level
    pipelines are bound here instead of being listed as inputs. The previous
    wiring also listed `pipe` twice, so the pruned pipeline was never used.
    """
    # NOTE(review): under ZeroGPU the arguments of a @spaces.GPU function may
    # be serialized to a worker process; confirm that handing the pipelines
    # over as arguments is acceptable there.
    return generate_images(prompt, seed, steps, pipe, pruned_pipe)


with gr.Blocks(css=css) as demo:
    gr.Markdown(header)

    # Generation controls share one row; `scale` sets their relative widths.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=50,
            step=1,
            scale=1,
        )

    generate_btn = gr.Button("Generate Images")

    # Clickable sample prompts that fill the prompt textbox.
    gr.Examples(
        examples=[
            "A clock tower floating in a sea of clouds",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
            "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
        ],
        inputs=[prompt],
    )

    # Side-by-side results from the two pipelines.
    with gr.Row():
        original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")

    # Generate on button click or when the prompt is submitted.
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=_generate_for_ui,
        inputs=[
            prompt,
            seed,
            steps,
        ],
        outputs=[original_output, ecodiff_output],
    )

if __name__ == "__main__":
    demo.launch()
|