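"""Gradio demo for EcoDiff-pruned FLUX.1-schnell (20% pruning ratio).

Runs the same prompt and seed through the original FluxPipeline and a pruned
variant from zhangyang-0123/EcoDiffPrunedModels, showing both outputs side by side.
"""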
import gradio as gr
import numpy as np
import random
from huggingface_hub import hf_hub_download
# import spaces  # [uncomment to use ZeroGPU]
from diffusers import FluxPipeline
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-schnell"  # replace with the model you would like to use
torch_dtype = torch.bfloat16

# Load the original (unpruned) pipeline.
pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Load a second pipeline and swap in the EcoDiff-pruned transformer.
# The checkpoint is a fully pickled module, so torch>=2.6 needs weights_only=False.
pruned_pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to("cpu")
pruned_pipe.transformer = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"),
    map_location="cpu",
    weights_only=False,
)
pruned_pipe = pruned_pipe.to(device)
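
# Upper bounds for seed and image size (not currently referenced elsewhere in this script).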
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# @spaces.GPU
def generate_images(prompt, seed, steps):
    pipe.to(device)
    pruned_pipe.to(device)
    # Run both pipelines with the same seed so the two outputs are directly comparable.
    generator = torch.Generator(device).manual_seed(seed)
    original_image = pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    generator = torch.Generator(device).manual_seed(seed)
    ecodiff_image = pruned_pipe(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]
    return original_image, ecodiff_image

examples = [
    "A clock tower floating in a sea of clouds",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
    "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
]
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
header = """
# 🌱 EcoDiff Pruned FLUX-Schnell (20% Pruning Ratio)
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/ariXv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
</div>
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(header)
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=5,
            step=1,
            scale=1,
        )
        generate_btn = gr.Button("Generate Images")
    gr.Examples(
        examples=examples,
        inputs=[prompt],
    )
    with gr.Row():
        original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")

    # Both the button click and pressing Enter in the prompt box trigger generation.
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_images,
        inputs=[prompt, seed, steps],
        outputs=[original_output, ecodiff_output],
    )

if __name__ == "__main__":
    demo.launch()